Improve benchmark scripts and error message printing (#2922)

Renames the generated-shared-prefix benchmark flags from --gen-* to --gsp-*, adds a --sharegpt-context-len option that prunes ShareGPT requests exceeding the model's context, threads a log_metrics flag through GenerateReqInput and EmbeddingReqInput so health checks stay out of metrics and request dumps, makes the scheduler's truncation warning self-describing, and wraps the TokenizerManager's background loops so their exceptions are printed instead of silently swallowed.
@@ -39,14 +39,15 @@ class BenchArgs:
     dataset_path: str = ""
     num_prompts: int = 1000
     sharegpt_output_len: Optional[int] = None
+    sharegpt_context_len: Optional[int] = None
     random_input_len: int = 1024
     random_output_len: int = 1024
     random_range_ratio: float = 0.0
-    gen_num_groups: int = 64
-    gen_prompts_per_group: int = 16
-    gen_system_prompt_len: int = 2048
-    gen_question_len: int = 128
-    gen_output_len: int = 256
+    gsp_num_groups: int = 64
+    gsp_prompts_per_group: int = 16
+    gsp_system_prompt_len: int = 2048
+    gsp_question_len: int = 128
+    gsp_output_len: int = 256
     disable_ignore_eos: bool = False
     extra_request_body: Optional[str] = None
     seed: int = 1
@@ -82,6 +83,12 @@ class BenchArgs:
             default=BenchArgs.sharegpt_output_len,
             help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
         )
+        parser.add_argument(
+            "--sharegpt-context-len",
+            type=int,
+            default=BenchArgs.sharegpt_context_len,
+            help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
+        )
         parser.add_argument(
             "--random-input-len",
             type=int,
@@ -102,35 +109,35 @@ class BenchArgs:
             "used only for random dataset.",
         )
         parser.add_argument(
-            "--gen-num-groups",
+            "--gsp-num-groups",
             type=int,
-            default=BenchArgs.gen_num_groups,
+            default=BenchArgs.gsp_num_groups,
             help="Number of groups with shared prefix, used"
             "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-prompts-per-group",
+            "--gsp-prompts-per-group",
             type=int,
-            default=BenchArgs.gen_prompts_per_group,
+            default=BenchArgs.gsp_prompts_per_group,
             help="Number of prompts per group of shared prefix, used"
             "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-system-prompt-len",
+            "--gsp-system-prompt-len",
             type=int,
-            default=BenchArgs.gen_system_prompt_len,
+            default=BenchArgs.gsp_system_prompt_len,
             help="System prompt length, used" "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-question-len",
+            "--gsp-question-len",
             type=int,
-            default=BenchArgs.gen_question_len,
+            default=BenchArgs.gsp_question_len,
             help="Question length, used" "only for generate-shared-prefix",
         )
         parser.add_argument(
-            "--gen-output-len",
+            "--gsp-output-len",
             type=int,
-            default=BenchArgs.gen_output_len,
+            default=BenchArgs.gsp_output_len,
             help="Target length in tokens for outputs in generated-shared-prefix dataset",
         )
         parser.add_argument(
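Note on the rename: argparse converts dashes to underscores, so --gsp-num-groups lands on args.gsp_num_groups and lines up with the renamed dataclass fields. A minimal, self-contained sketch of the pattern (a hypothetical two-field stand-in, not the real BenchArgs):

import argparse
from dataclasses import dataclass

@dataclass
class BenchArgs:  # hypothetical stand-in carrying only two of the renamed fields
    gsp_num_groups: int = 64
    gsp_prompts_per_group: int = 16

parser = argparse.ArgumentParser()
parser.add_argument("--gsp-num-groups", type=int, default=BenchArgs.gsp_num_groups)
parser.add_argument(
    "--gsp-prompts-per-group", type=int, default=BenchArgs.gsp_prompts_per_group
)

# Dashes become underscores, so parsed attributes match the dataclass fields.
args = parser.parse_args(["--gsp-num-groups", "32"])
assert args.gsp_num_groups == 32
assert args.gsp_prompts_per_group == 16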
@@ -452,6 +452,7 @@ def get_dataset(args, tokenizer):
             num_requests=args.num_prompts,
             tokenizer=tokenizer,
             fixed_output_len=args.sharegpt_output_len,
+            context_len=args.sharegpt_context_len,
         )
     elif args.dataset_name == "random":
         input_requests = sample_random_requests(
@@ -464,11 +465,11 @@ def get_dataset(args, tokenizer):
         )
     elif args.dataset_name == "generated-shared-prefix":
         input_requests = sample_generated_shared_prefix_requests(
-            num_groups=args.gen_num_groups,
-            prompts_per_group=args.gen_prompts_per_group,
-            system_prompt_len=args.gen_system_prompt_len,
-            question_len=args.gen_question_len,
-            output_len=args.gen_output_len,
+            num_groups=args.gsp_num_groups,
+            prompts_per_group=args.gsp_prompts_per_group,
+            system_prompt_len=args.gsp_system_prompt_len,
+            question_len=args.gsp_question_len,
+            output_len=args.gsp_output_len,
             tokenizer=tokenizer,
         )
     else:
@@ -560,6 +561,7 @@ def sample_sharegpt_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
    fixed_output_len: Optional[int] = None,
+    context_len: Optional[int] = None,
 ) -> List[Tuple[str, int, int]]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
@@ -597,14 +599,15 @@ def sample_sharegpt_requests(
         output_len = (
             len(completion_token_ids) if fixed_output_len is None else fixed_output_len
         )
-        if prompt_len < 4 or output_len < 4:
+
+        if prompt_len < 1 or output_len < 1:
             # Prune too short sequences.
             continue
-        if prompt_len > 1024 or (
-            prompt_len + output_len > 2048 and fixed_output_len is None
-        ):
+
+        if context_len and prompt_len + output_len > context_len:
             # Prune too long sequences.
             continue
+
         filtered_dataset.append((prompt, prompt_len, output_len))

     print(f"#Input tokens: {np.sum([x[1] for x in filtered_dataset])}")
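The pruning change above is behavioral, not just cosmetic: the old hard-coded caps (prompt over 1024 tokens, or prompt plus output over 2048) are gone, and requests are now dropped only when they would overflow an explicitly supplied context length. A standalone sketch of the new filter (the function name and sample data are mine, not from the commit):

from typing import List, Optional, Tuple

def prune_requests(
    requests: List[Tuple[str, int, int]],  # (prompt, prompt_len, output_len)
    context_len: Optional[int] = None,
) -> List[Tuple[str, int, int]]:
    filtered = []
    for prompt, prompt_len, output_len in requests:
        if prompt_len < 1 or output_len < 1:
            continue  # prune empty sequences
        if context_len and prompt_len + output_len > context_len:
            continue  # prune requests that would overflow the context
        filtered.append((prompt, prompt_len, output_len))
    return filtered

reqs = [("a", 0, 5), ("b", 100, 50), ("c", 2000, 200)]
# With context_len unset, nothing is dropped for length anymore.
assert prune_requests(reqs) == [("b", 100, 50), ("c", 2000, 200)]
assert prune_requests(reqs, context_len=1024) == [("b", 100, 50)]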
@@ -706,8 +709,8 @@ def get_gen_prefix_cache_path(args, tokenizer):

     # Create a unique cache filename based on the generation parameters
     cache_key = (
-        f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
-        f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
+        f"gen_shared_prefix_{args.gsp_num_groups}_{args.gsp_prompts_per_group}_"
+        f"{args.gsp_system_prompt_len}_{args.gsp_question_len}_{args.gsp_output_len}_"
         f"{tokenizer.__class__.__name__}.pkl"
     )
     return cache_dir / cache_key
@@ -1374,6 +1377,12 @@ if __name__ == "__main__":
         default=None,
         help="Output length for each request. Overrides the output length from the ShareGPT dataset.",
     )
+    parser.add_argument(
+        "--sharegpt-context-len",
+        type=int,
+        default=None,
+        help="The context length of the model for the ShareGPT dataset. Requests longer than the context length will be dropped.",
+    )
     parser.add_argument(
         "--random-input-len",
         type=int,
@@ -1453,38 +1462,6 @@ if __name__ == "__main__":
         help="Append given JSON object to the request payload. You can use this to specify"
         "additional generate params like sampling params.",
     )
-
-    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
-    group.add_argument(
-        "--gen-num-groups",
-        type=int,
-        default=64,
-        help="Number of system prompt groups for generated-shared-prefix dataset",
-    )
-    group.add_argument(
-        "--gen-prompts-per-group",
-        type=int,
-        default=16,
-        help="Number of prompts per system prompt group for generated-shared-prefix dataset",
-    )
-    group.add_argument(
-        "--gen-system-prompt-len",
-        type=int,
-        default=2048,
-        help="Target length in tokens for system prompts in generated-shared-prefix dataset",
-    )
-    group.add_argument(
-        "--gen-question-len",
-        type=int,
-        default=128,
-        help="Target length in tokens for questions in generated-shared-prefix dataset",
-    )
-    group.add_argument(
-        "--gen-output-len",
-        type=int,
-        default=256,
-        help="Target length in tokens for outputs in generated-shared-prefix dataset",
-    )
     parser.add_argument(
         "--profile",
         action="store_true",
@@ -1497,5 +1474,37 @@ if __name__ == "__main__":
         default=None,
         help="The name of LoRA adapter",
     )
+
+    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
+    group.add_argument(
+        "--gsp-num-groups",
+        type=int,
+        default=64,
+        help="Number of system prompt groups for generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gsp-prompts-per-group",
+        type=int,
+        default=16,
+        help="Number of prompts per system prompt group for generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gsp-system-prompt-len",
+        type=int,
+        default=2048,
+        help="Target length in tokens for system prompts in generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gsp-question-len",
+        type=int,
+        default=128,
+        help="Target length in tokens for questions in generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gsp-output-len",
+        type=int,
+        default=256,
+        help="Target length in tokens for outputs in generated-shared-prefix dataset",
+    )
     args = parser.parse_args()
     run_benchmark(args)
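Relocating these flags into an argument group at the end of the main block only changes how --help is rendered; parsing behavior is identical. A small sketch of the grouped-flag pattern (hypothetical program name and flag subset):

import argparse

parser = argparse.ArgumentParser(prog="bench")  # hypothetical program name
group = parser.add_argument_group("generated-shared-prefix dataset arguments")
group.add_argument("--gsp-num-groups", type=int, default=64)
group.add_argument("--gsp-output-len", type=int, default=256)

# The group affects only the --help layout; parsing is unchanged.
args = parser.parse_args(["--gsp-output-len", "512"])
assert args.gsp_output_len == 512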
@@ -59,6 +59,9 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True
+
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
     # LoRA related
@@ -196,6 +199,7 @@ class GenerateReqInput:
                 top_logprobs_num=self.top_logprobs_num[i],
                 return_text_in_logprobs=self.return_text_in_logprobs,
                 stream=self.stream,
+                log_metrics=self.log_metrics,
                 modalities=self.modalities[i] if self.modalities else None,
                 lora_path=self.lora_path[i] if self.lora_path is not None else None,
             )
@@ -243,6 +247,8 @@ class EmbeddingReqInput:
     sampling_params: Union[List[Dict], Dict] = None
     # Dummy input embeds for compatibility
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
+    # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
+    log_metrics: bool = True

     def normalize_batch_and_arguments(self):
         if (self.text is None and self.input_ids is None) or (
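Because log_metrics is a scalar on a batched request, splitting the batch into per-index requests has to copy it into every child, just like stream above. A stripped-down sketch of that propagation (BatchedReq is a hypothetical stand-in for GenerateReqInput):

from dataclasses import dataclass
from typing import List

@dataclass
class BatchedReq:  # hypothetical stand-in for GenerateReqInput
    texts: List[str]          # per-request data, one entry per batch item
    stream: bool = False      # scalar flags shared by the whole batch
    log_metrics: bool = True

    def __getitem__(self, i: int) -> "BatchedReq":
        return BatchedReq(
            texts=[self.texts[i]],
            stream=self.stream,
            log_metrics=self.log_metrics,  # propagated when splitting the batch
        )

batch = BatchedReq(texts=["ping", "pong"], log_metrics=False)
assert batch[1].log_metrics is False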
@@ -631,7 +631,8 @@ class Scheduler:
         if len(req.origin_input_ids) > self.max_req_input_len:
             logger.warning(
                 "Request length is longer than the KV cache pool size or "
-                "the max context length. Truncated!!!"
+                "the max context length. Truncated. "
+                f"{len(req.origin_input_ids)=}, {self.max_req_input_len=}."
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]

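The rewritten warning relies on self-documenting f-strings (the f"{expr=}" form, Python 3.8 and later), which render the expression text alongside its value, so the log line states exactly which lengths were compared:

x = list(range(3))
limit = 2
# f"{expr=}" renders both the expression and its value.
print(f"{len(x)=}, {limit=}.")  # prints: len(x)=3, limit=2.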
@@ -79,6 +79,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -640,7 +641,9 @@ class TokenizerManager:

         self.to_create_loop = False
         loop = asyncio.get_event_loop()
-        self.asyncio_tasks.add(loop.create_task(self.handle_loop()))
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.handle_loop))
+        )

         # We cannot add signal handler when the tokenizer manager is not in
         # the main thread due to the CPython limitation.
@@ -653,7 +656,9 @@ class TokenizerManager:
                 "not in the main thread. This disables graceful shutdown of the "
                 "tokenizer manager when SIGTERM is received."
             )
-        self.asyncio_tasks.add(loop.create_task(self.sigterm_watchdog()))
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+        )

     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
@@ -738,9 +743,13 @@ class TokenizerManager:
                 state.finished = recv_obj.finished_reasons[i] is not None
                 state.event.set()

-                if self.enable_metrics:
+                if self.enable_metrics and state.obj.log_metrics:
                     self.collect_metrics(state, recv_obj, i)
-                if self.dump_requests_folder and state.finished:
+                if (
+                    self.dump_requests_folder
+                    and state.finished
+                    and state.obj.log_metrics
+                ):
                     self.dump_requests(state, out_dict)
         elif isinstance(recv_obj, OpenSessionReqOutput):
             self.session_futures[recv_obj.session_id].set_result(
@@ -887,20 +896,38 @@ class TokenizerManager:
         )

         if len(self.dump_request_list) >= self.dump_requests_threshold:
+            filename = os.path.join(
+                self.dump_requests_folder,
+                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
+            )
+            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
+
             to_dump = self.dump_request_list
             self.dump_request_list = []

             def background_task():
                 os.makedirs(self.dump_requests_folder, exist_ok=True)
-                current_time = datetime.now()
-                filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
-                with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
+                with open(filename, "wb") as f:
                     pickle.dump(to_dump, f)

             # Schedule the task to run in the background without awaiting it
             asyncio.create_task(asyncio.to_thread(background_task))


+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio function does not print exception.
+    We do another wrapper to handle the exception.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"TokenizerManager hit an exception: {traceback}")
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
+
+
 class SignalHandler:
     def __init__(self, tokenizer_manager):
         self.tokenizer_manager = tokenizer_manager
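The dump_requests change computes the filename and logs it on the event loop, then hands only the blocking pickle write to a worker thread. A self-contained sketch of this fire-and-forget pattern (the temp-file path and sample records are mine):

import asyncio
import os
import pickle
import tempfile

async def dump_in_background(records):
    # Compute and announce the target file on the event loop thread.
    filename = os.path.join(tempfile.gettempdir(), "requests_dump.pkl")
    print(f"Dump {len(records)} requests to {filename}")

    def background_task():
        # Only the blocking disk write runs in the worker thread.
        with open(filename, "wb") as f:
            pickle.dump(records, f)

    # Schedule without awaiting, so the caller is not blocked by disk I/O.
    asyncio.create_task(asyncio.to_thread(background_task))
    await asyncio.sleep(0.1)  # keep the loop alive long enough for the write

asyncio.run(dump_in_background([{"rid": 1}, {"rid": 2}]))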
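print_exception_wrapper exists because an exception inside a task that nobody awaits may only be reported when the task object is garbage-collected, if at all; wrapping the coroutine function surfaces the failure immediately. A runnable sketch of the idea (printing instead of logging, and without the process-tree kill the commit performs):

import asyncio

async def crash():
    raise RuntimeError("boom")

async def print_exception_wrapper(func):
    # Await the wrapped coroutine function and report failures eagerly,
    # instead of letting the task swallow them.
    try:
        await func()
    except Exception as exc:
        print(f"background task hit an exception: {exc!r}")
        # The commit additionally kills the process tree and exits here,
        # making the failure fatal rather than silent.

async def main():
    asyncio.create_task(print_exception_wrapper(crash))
    await asyncio.sleep(0.1)  # the error is printed even though no one awaits the task

asyncio.run(main())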
@@ -135,9 +135,13 @@ async def health_generate(request: Request) -> Response:
     sampling_params = {"max_new_tokens": 1, "temperature": 0.7}

     if tokenizer_manager.is_generation:
-        gri = GenerateReqInput(input_ids=[0], sampling_params=sampling_params)
+        gri = GenerateReqInput(
+            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+        )
     else:
-        gri = EmbeddingReqInput(input_ids=[0], sampling_params=sampling_params)
+        gri = EmbeddingReqInput(
+            input_ids=[0], sampling_params=sampling_params, log_metrics=False
+        )

     try:
         async for _ in tokenizer_manager.generate_request(gri, request):
@@ -560,6 +560,7 @@ def run_bench_serving(
         tokenizer=tokenizer,
         num_prompts=num_prompts,
         sharegpt_output_len=None,
+        sharegpt_context_len=None,
         random_input_len=random_input_len,
         random_output_len=random_output_len,
         random_range_ratio=0.0,