Replace prob based with threshold based load balancing (#2170)

This commit is contained in:
Byron Hsu
2024-11-24 23:17:11 -08:00
committed by GitHub
parent 8e1adb8441
commit 4b0a1c9365
7 changed files with 223 additions and 151 deletions

View File

@@ -25,6 +25,7 @@ import warnings
from argparse import ArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
import aiohttp
@@ -693,6 +694,19 @@ def gen_prompt(tokenizer, token_num):
return tokenizer.decode(selected_tokens)
def get_gen_prefix_cache_path(args, tokenizer):
    """Return the cache file path for a generated shared-prefix dataset.

    The path points into ``~/.cache/sglang/benchmark``. The file name encodes
    the generation settings read from ``args`` (``gen_num_groups``,
    ``gen_prompts_per_group``, ``gen_system_prompt_len``, ``gen_question_len``,
    ``gen_output_len``) followed by the class name of ``tokenizer``.

    This function only builds the path; it does not create the directory or
    the file.
    """
    benchmark_cache_root = Path.home().joinpath(".cache", "sglang", "benchmark")
    tokenizer_class_name = tokenizer.__class__.__name__
    # One file per combination of generation settings and tokenizer class.
    file_name = (
        f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}"
        f"_{args.gen_system_prompt_len}_{args.gen_question_len}"
        f"_{args.gen_output_len}_{tokenizer_class_name}.pkl"
    )
    return benchmark_cache_root / file_name
def sample_generated_shared_prefix_requests(
num_groups: int,
prompts_per_group: int,
@@ -701,12 +715,17 @@ def sample_generated_shared_prefix_requests(
output_len: int,
tokenizer: PreTrainedTokenizerBase,
) -> List[Tuple[str, int, int]]:
if args.generated_input_path and os.path.exists(args.generated_input_path):
print(f"\nloading generated input data from {args.generated_input_path}")
with open(args.generated_input_path, "rb") as f:
"""Generate benchmark requests with shared system prompts using random tokens and caching."""
cache_path = get_gen_prefix_cache_path(args, tokenizer)
# Try to load from cache first
if cache_path.exists():
print(f"\nLoading cached generated input data from {cache_path}")
with open(cache_path, "rb") as f:
return pickle.load(f)
"""Generate benchmark requests with shared system prompts using random tokens."""
print("\nGenerating new input data...")
# Generate system prompts for each group
system_prompts = []
for _ in range(num_groups):
@@ -719,9 +738,6 @@ def sample_generated_shared_prefix_requests(
question = gen_prompt(tokenizer, question_len)
questions.append(question)
# Shuffle questions
random.shuffle(questions)
# Combine system prompts with questions
input_requests = []
total_input_tokens = 0
@@ -729,7 +745,9 @@ def sample_generated_shared_prefix_requests(
for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
system_prompt = system_prompts[group_idx]
for prompt_idx in tqdm(range(prompts_per_group), desc="Generating questions"):
for prompt_idx in tqdm(
range(prompts_per_group), desc="Generating questions", leave=False
):
question = questions[group_idx * prompts_per_group + prompt_idx]
full_prompt = f"{system_prompt}\n\n{question}"
prompt_len = len(tokenizer.encode(full_prompt))
@@ -738,6 +756,10 @@ def sample_generated_shared_prefix_requests(
total_input_tokens += prompt_len
total_output_tokens += output_len
# Shuffle the combined requests
random.shuffle(input_requests)
# Print statistics
print(f"\nGenerated shared prefix dataset statistics:")
print(f"Number of groups: {num_groups}")
print(f"Prompts per group: {prompts_per_group}")
@@ -750,11 +772,12 @@ def sample_generated_shared_prefix_requests(
print(
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
)
if args.generated_input_save_path:
print(f"Saving generated input data to {args.generated_input_save_path}")
os.makedirs(os.path.dirname(args.generated_input_save_path), exist_ok=True)
with open(args.generated_input_save_path, "wb") as f:
pickle.dump(input_requests, f)
# Save to cache
cache_path.parent.mkdir(parents=True, exist_ok=True)
print(f"Caching generated input data to {cache_path}")
with open(cache_path, "wb") as f:
pickle.dump(input_requests, f)
return input_requests
@@ -1422,16 +1445,6 @@ if __name__ == "__main__":
default=256,
help="Target length in tokens for outputs in generated-shared-prefix dataset",
)
parser.add_argument(
"--generated-input-save-path",
type=str,
help="Path to save generated input data",
)
parser.add_argument(
"--generated-input-path",
type=str,
help="Path to load previously generated input data",
)
parser.add_argument(
"--profile",
action="store_true",