Improve benchmark (#1140)
This commit is contained in:
@@ -149,10 +149,12 @@ async def async_request_openai_completions(
|
||||
"completions"
|
||||
), "OpenAI Completions API URL must end with 'completions'."
|
||||
|
||||
prompt = request_func_input.prompt
|
||||
|
||||
async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
|
||||
payload = {
|
||||
"model": request_func_input.model,
|
||||
"prompt": request_func_input.prompt,
|
||||
"prompt": prompt,
|
||||
"temperature": 0.0,
|
||||
"best_of": 1,
|
||||
"max_tokens": request_func_input.output_len,
|
||||
@@ -220,6 +222,13 @@ async def async_request_openai_completions(
|
||||
return output
|
||||
|
||||
|
||||
async def async_request_ginfer(
|
||||
request_func_input: RequestFuncInput,
|
||||
pbar: Optional[tqdm] = None,
|
||||
) -> RequestFuncOutput:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
|
||||
import huggingface_hub.constants
|
||||
@@ -238,6 +247,13 @@ def get_model(pretrained_model_name_or_path: str) -> str:
|
||||
def get_tokenizer(
|
||||
pretrained_model_name_or_path: str,
|
||||
) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
|
||||
if pretrained_model_name_or_path.endswith(
|
||||
".json"
|
||||
) or pretrained_model_name_or_path.endswith(".model"):
|
||||
from sglang.srt.hf_transformers_utils import get_tokenizer
|
||||
|
||||
return get_tokenizer(pretrained_model_name_or_path)
|
||||
|
||||
if pretrained_model_name_or_path is not None and not os.path.exists(
|
||||
pretrained_model_name_or_path
|
||||
):
|
||||
@@ -252,6 +268,7 @@ ASYNC_REQUEST_FUNCS = {
|
||||
"vllm": async_request_openai_completions,
|
||||
"lmdeploy": async_request_openai_completions,
|
||||
"trt": async_request_trt_llm,
|
||||
"ginfer": async_request_ginfer,
|
||||
}
|
||||
|
||||
|
||||
@@ -351,9 +368,9 @@ def sample_sharegpt_requests(
|
||||
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = dataset[i][0]
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
prompt_token_ids = tokenizer.encode(prompt)
|
||||
completion = dataset[i][1]
|
||||
completion_token_ids = tokenizer(completion).input_ids
|
||||
completion_token_ids = tokenizer.encode(completion)
|
||||
prompt_len = len(prompt_token_ids)
|
||||
output_len = (
|
||||
len(completion_token_ids) if fixed_output_len is None else fixed_output_len
|
||||
@@ -361,7 +378,9 @@ def sample_sharegpt_requests(
|
||||
if prompt_len < 4 or output_len < 4:
|
||||
# Prune too short sequences.
|
||||
continue
|
||||
if prompt_len > 1024 or prompt_len + output_len > 2048:
|
||||
if prompt_len > 1024 or (
|
||||
prompt_len + output_len > 2048 and fixed_output_len is None
|
||||
):
|
||||
# Prune too long sequences.
|
||||
continue
|
||||
filtered_dataset.append((prompt, prompt_len, output_len))
|
||||
@@ -422,7 +441,7 @@ def sample_random_requests(
|
||||
for i in range(num_prompts):
|
||||
# Tokenize the prompts and completions.
|
||||
prompt = dataset[i][0]
|
||||
prompt_token_ids = tokenizer(prompt).input_ids
|
||||
prompt_token_ids = tokenizer.encode(prompt)
|
||||
prompt_len = len(prompt_token_ids)
|
||||
|
||||
if prompt_len > input_lens[i]:
|
||||
@@ -488,7 +507,7 @@ def calculate_metrics(
|
||||
output_len = outputs[i].output_len
|
||||
output_lens.append(output_len)
|
||||
retokenized_output_len = len(
|
||||
tokenizer(outputs[i].generated_text, add_special_tokens=False).input_ids
|
||||
tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
|
||||
)
|
||||
retokenized_output_lens.append(retokenized_output_len)
|
||||
total_input += input_requests[i][1]
|
||||
@@ -547,7 +566,6 @@ async def benchmark(
|
||||
input_requests: List[Tuple[str, int, int]],
|
||||
request_rate: float,
|
||||
disable_tqdm: bool,
|
||||
enable_multi: bool,
|
||||
extra_request_body: Dict[str, Any],
|
||||
):
|
||||
if backend in ASYNC_REQUEST_FUNCS:
|
||||
@@ -756,6 +774,7 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
global args
|
||||
args = args_
|
||||
|
||||
# Set global environments
|
||||
set_ulimit()
|
||||
random.seed(args.seed)
|
||||
np.random.seed(args.seed)
|
||||
@@ -764,12 +783,14 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
if args.extra_request_body:
|
||||
extra_request_body = json.loads(args.extra_request_body)
|
||||
|
||||
# Set url
|
||||
if args.port is None:
|
||||
args.port = {
|
||||
"sglang": 30000,
|
||||
"lmdeploy": 23333,
|
||||
"vllm": 8000,
|
||||
"trt": 8000,
|
||||
"ginfer": 9988,
|
||||
}.get(args.backend, 30000)
|
||||
|
||||
api_url = (
|
||||
@@ -792,7 +813,11 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
if args.model is None:
|
||||
print("Please provide a model using `--model` when using `trt` backend.")
|
||||
sys.exit(1)
|
||||
elif args.backend == "ginfer":
|
||||
api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
|
||||
args.model = args.model or "default"
|
||||
|
||||
# Get model name
|
||||
if args.model is None:
|
||||
try:
|
||||
response = requests.get(model_url)
|
||||
@@ -817,6 +842,7 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
|
||||
print(f"{args}\n")
|
||||
|
||||
# Read dataset
|
||||
backend = args.backend
|
||||
model_id = args.model
|
||||
tokenizer_id = args.tokenizer if args.tokenizer is not None else args.model
|
||||
@@ -842,7 +868,21 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
else:
|
||||
raise ValueError(f"Unknown dataset: {args.dataset_name}")
|
||||
|
||||
if args.multi:
|
||||
if not args.multi:
|
||||
return asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_rate=args.request_rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Benchmark multiple rps. TODO: use a fixed duration to compute num_prompts
|
||||
request_rates = parse_request_rate_range(args.request_rate_range)
|
||||
|
||||
for rate in request_rates:
|
||||
@@ -855,27 +895,11 @@ def run_benchmark(args_: argparse.Namespace):
|
||||
input_requests=input_requests,
|
||||
request_rate=rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
enable_multi=args.multi,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
)
|
||||
else:
|
||||
return asyncio.run(
|
||||
benchmark(
|
||||
backend=backend,
|
||||
api_url=api_url,
|
||||
model_id=model_id,
|
||||
tokenizer=tokenizer,
|
||||
input_requests=input_requests,
|
||||
request_rate=args.request_rate,
|
||||
disable_tqdm=args.disable_tqdm,
|
||||
enable_multi=args.multi,
|
||||
extra_request_body=extra_request_body,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# to avoid relying on SGLang's components
|
||||
def set_ulimit(target_soft_limit=65535):
|
||||
resource_type = resource.RLIMIT_NOFILE
|
||||
current_soft, current_hard = resource.getrlimit(resource_type)
|
||||
@@ -968,7 +992,7 @@ if __name__ == "__main__":
|
||||
help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
|
||||
"Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
|
||||
parser.add_argument("--seed", type=int, default=1, help="The random seed.")
|
||||
parser.add_argument(
|
||||
"--multi",
|
||||
action="store_true",
|
||||
|
||||
@@ -30,7 +30,17 @@ from transformers import (
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
|
||||
|
||||
try:
|
||||
from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
|
||||
|
||||
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||
ChatGLMConfig.model_type: ChatGLMConfig,
|
||||
DbrxConfig.model_type: DbrxConfig,
|
||||
}
|
||||
except ImportError:
|
||||
# We want this file to run without vllm dependency
|
||||
_CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {}
|
||||
|
||||
from sglang.srt.utils import is_multimodal_model
|
||||
|
||||
|
||||
@@ -113,30 +113,7 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
|
||||
|
||||
|
||||
def call_generate_ginfer(prompt, temperature, max_tokens, stop=None, url=None):
|
||||
import grpc
|
||||
from ginfer import sampler_pb2, sampler_pb2_grpc
|
||||
|
||||
sampler_channel = grpc.insecure_channel(url.replace("http://", ""))
|
||||
sampler = sampler_pb2_grpc.SamplerStub(sampler_channel)
|
||||
|
||||
if stop is None:
|
||||
stop_strings = None
|
||||
else:
|
||||
stop_strings = [stop]
|
||||
|
||||
sample_request = sampler_pb2.SampleTextRequest(
|
||||
prompt=prompt,
|
||||
settings=sampler_pb2.SampleSettings(
|
||||
max_len=max_tokens,
|
||||
rng_seed=0,
|
||||
temperature=max(temperature, 1e-7),
|
||||
nucleus_p=1,
|
||||
stop_strings=stop_strings,
|
||||
),
|
||||
)
|
||||
stream = sampler.SampleText(sample_request)
|
||||
response = "".join([x.text for x in stream])
|
||||
return response
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def call_generate_guidance(
|
||||
|
||||
Reference in New Issue
Block a user