Replace prob based with threshold based load balancing (#2170)
This commit is contained in:
@@ -25,6 +25,7 @@ import warnings
|
||||
from argparse import ArgumentParser
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
@@ -693,6 +694,19 @@ def gen_prompt(tokenizer, token_num):
|
||||
return tokenizer.decode(selected_tokens)
|
||||
|
||||
|
||||
def get_gen_prefix_cache_path(args, tokenizer):
    """Return the cache file path for a generated shared-prefix dataset.

    The path lives under ``~/.cache/sglang/benchmark``. This function only
    computes the path; it does not create the directory or the file.

    The filename encodes every generation parameter read from ``args``
    (``gen_num_groups``, ``gen_prompts_per_group``, ``gen_system_prompt_len``,
    ``gen_question_len``, ``gen_output_len``) plus an identifier for the
    tokenizer, so that datasets built with different settings do not share
    a cache file.

    Args:
        args: Namespace-like object exposing the ``gen_*`` attributes above.
        tokenizer: Tokenizer used to generate the prompts. Its class name is
            always part of the key; when it exposes a non-empty
            ``name_or_path`` attribute, a short hash of that value is added.

    Returns:
        A ``pathlib.Path`` pointing at the ``.pkl`` cache file.
    """
    import hashlib

    cache_dir = Path.home() / ".cache" / "sglang" / "benchmark"

    # The class name alone is not unique: many different models share one
    # tokenizer class, and reusing prompts/token counts produced with another
    # vocabulary would silently skew the benchmark. Mix in the model
    # identifier when the tokenizer provides one.
    tokenizer_tag = tokenizer.__class__.__name__
    name_or_path = getattr(tokenizer, "name_or_path", None)
    if name_or_path:
        # Hash instead of embedding the raw value: it may contain path
        # separators or be arbitrarily long.
        digest = hashlib.sha256(str(name_or_path).encode("utf-8")).hexdigest()[:8]
        tokenizer_tag = f"{tokenizer_tag}_{digest}"

    # Create a unique cache filename based on the generation parameters
    cache_key = (
        f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
        f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
        f"{tokenizer_tag}.pkl"
    )
    return cache_dir / cache_key
|
||||
|
||||
|
||||
def sample_generated_shared_prefix_requests(
|
||||
num_groups: int,
|
||||
prompts_per_group: int,
|
||||
@@ -701,12 +715,17 @@ def sample_generated_shared_prefix_requests(
|
||||
output_len: int,
|
||||
tokenizer: PreTrainedTokenizerBase,
|
||||
) -> List[Tuple[str, int, int]]:
|
||||
if args.generated_input_path and os.path.exists(args.generated_input_path):
|
||||
print(f"\nloading generated input data from {args.generated_input_path}")
|
||||
with open(args.generated_input_path, "rb") as f:
|
||||
"""Generate benchmark requests with shared system prompts using random tokens and caching."""
|
||||
cache_path = get_gen_prefix_cache_path(args, tokenizer)
|
||||
|
||||
# Try to load from cache first
|
||||
if cache_path.exists():
|
||||
print(f"\nLoading cached generated input data from {cache_path}")
|
||||
with open(cache_path, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
"""Generate benchmark requests with shared system prompts using random tokens."""
|
||||
print("\nGenerating new input data...")
|
||||
|
||||
# Generate system prompts for each group
|
||||
system_prompts = []
|
||||
for _ in range(num_groups):
|
||||
@@ -719,9 +738,6 @@ def sample_generated_shared_prefix_requests(
|
||||
question = gen_prompt(tokenizer, question_len)
|
||||
questions.append(question)
|
||||
|
||||
# Shuffle questions
|
||||
random.shuffle(questions)
|
||||
|
||||
# Combine system prompts with questions
|
||||
input_requests = []
|
||||
total_input_tokens = 0
|
||||
@@ -729,7 +745,9 @@ def sample_generated_shared_prefix_requests(
|
||||
|
||||
for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
|
||||
system_prompt = system_prompts[group_idx]
|
||||
for prompt_idx in tqdm(range(prompts_per_group), desc="Generating questions"):
|
||||
for prompt_idx in tqdm(
|
||||
range(prompts_per_group), desc="Generating questions", leave=False
|
||||
):
|
||||
question = questions[group_idx * prompts_per_group + prompt_idx]
|
||||
full_prompt = f"{system_prompt}\n\n{question}"
|
||||
prompt_len = len(tokenizer.encode(full_prompt))
|
||||
@@ -738,6 +756,10 @@ def sample_generated_shared_prefix_requests(
|
||||
total_input_tokens += prompt_len
|
||||
total_output_tokens += output_len
|
||||
|
||||
# Shuffle questions
|
||||
random.shuffle(input_requests)
|
||||
|
||||
# Print statistics
|
||||
print(f"\nGenerated shared prefix dataset statistics:")
|
||||
print(f"Number of groups: {num_groups}")
|
||||
print(f"Prompts per group: {prompts_per_group}")
|
||||
@@ -750,11 +772,12 @@ def sample_generated_shared_prefix_requests(
|
||||
print(
|
||||
f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
|
||||
)
|
||||
if args.generated_input_save_path:
|
||||
print(f"Saving generated input data to {args.generated_input_save_path}")
|
||||
os.makedirs(os.path.dirname(args.generated_input_save_path), exist_ok=True)
|
||||
with open(args.generated_input_save_path, "wb") as f:
|
||||
pickle.dump(input_requests, f)
|
||||
|
||||
# Save to cache
|
||||
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
print(f"Caching generated input data to {cache_path}")
|
||||
with open(cache_path, "wb") as f:
|
||||
pickle.dump(input_requests, f)
|
||||
|
||||
return input_requests
|
||||
|
||||
@@ -1422,16 +1445,6 @@ if __name__ == "__main__":
|
||||
default=256,
|
||||
help="Target length in tokens for outputs in generated-shared-prefix dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--generated-input-save-path",
|
||||
type=str,
|
||||
help="Path to save generated input data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--generated-input-path",
|
||||
type=str,
|
||||
help="Path to load previously generated input data",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--profile",
|
||||
action="store_true",
|
||||
|
||||
Reference in New Issue
Block a user