[router] cache-aware load-balancing router v1 (#2114)

This commit is contained in:
Byron Hsu
2024-11-23 08:34:48 -08:00
committed by GitHub
parent ad47749b82
commit cbedd1db1d
17 changed files with 1963 additions and 602 deletions

View File

@@ -1,21 +1,24 @@
import itertools
import json
import os
import random
import string
import threading
import time
from argparse import ArgumentParser
from pathlib import Path
from typing import Union
from tqdm import tqdm
import sglang as sgl
from sglang.srt.hf_transformers_utils import get_tokenize
from sglang.srt.hf_transformers_utils import get_tokenizer
from sglang.test.test_utils import (
add_common_sglang_args_and_parse,
select_sglang_backend,
)
from sglang.utils import dump_state_text
random.seed(42)
def gen_prompt(tokenizer, token_num):
all_available_tokens = list(tokenizer.get_vocab().values())
@@ -24,12 +27,34 @@ def gen_prompt(tokenizer, token_num):
return ret
def get_cache_path(args):
    """Return the JSON cache file path for the generated benchmark arguments.

    The file lives under ``~/.cache/sglang``. Its name encodes every argument
    read here (``num_qa``, ``turns``, ``system_prompt_len``, ``len_q``,
    ``len_a`` and ``tokenizer``), so changing any of them selects a
    different cache file.

    Args:
        args: Parsed command-line namespace providing the attributes above;
            ``args.tokenizer`` must be a string.

    Returns:
        pathlib.Path: Location of the cache file. Neither the file nor its
        parent directory is created by this function.
    """
    import re  # function-scope import: the pattern substitution is only needed here

    cache_dir = Path.home() / ".cache" / "sglang"

    # The tokenizer may be a hub id ("org/name") or a local filesystem path.
    # Replace every character that is not a word character, "." or "-" so
    # that separators such as "/", "\\" and ":" can never split the name into
    # sub-directories or produce a filename the OS rejects. For plain
    # "org/name" ids the result matches a simple "/" -> "_" replacement.
    safe_tokenizer = re.sub(r"[^\w.\-]", "_", args.tokenizer)

    cache_key = (
        f"qa_{args.num_qa}_{args.turns}_{args.system_prompt_len}"
        f"_{args.len_q}_{args.len_a}_{safe_tokenizer}.json"
    )
    return cache_dir / cache_key
def gen_arguments(args, tokenizer):
multi_qas = [
{"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
for _ in range(args.num_qa)
]
for i in range(args.num_qa):
cache_path = get_cache_path(args)
# Try to load from cache first
if cache_path.exists():
print(f"Loading cached arguments from {cache_path}")
with open(cache_path, "r") as f:
return json.load(f)
print("Generating new arguments...")
# First progress bar for system prompts
multi_qas = []
for _ in tqdm(range(args.num_qa), desc="Generating system prompts"):
multi_qas.append(
{"system_prompt": gen_prompt(tokenizer, args.system_prompt_len), "qas": []}
)
# Progress bar over the QA sets; each set receives args.turns QA pairs
for i in tqdm(range(args.num_qa), desc="Generating QA pairs"):
qas = multi_qas[i]["qas"]
for j in range(args.turns):
qas.append(
@@ -38,6 +63,13 @@ def gen_arguments(args, tokenizer):
"new_tokens": args.len_a,
}
)
# Save to cache
cache_path.parent.mkdir(parents=True, exist_ok=True)
with open(cache_path, "w") as f:
json.dump(multi_qas, f)
print(f"Cached arguments saved to {cache_path}")
return multi_qas
@@ -45,7 +77,7 @@ def gen_arguments(args, tokenizer):
def multi_turns(s, system_prompt, qas):
s += system_prompt
for qa in qas:
for i, qa in enumerate(qas):
s += qa["prompt"]
s += sgl.gen(max_tokens=qa["new_tokens"], ignore_eos=True)
@@ -62,7 +94,7 @@ def main(args):
multi_qas,
temperature=0,
backend=backend,
num_threads=args.parallel,
num_threads="auto",
progress_bar=True,
)
latency = time.time() - tic
@@ -75,7 +107,6 @@ def main(args):
value = {
"task": "multi_turn_system_prompt_chat",
"backend": args.backend,
"num_gpus": 1,
"latency": round(latency, 3),
"num_requests": args.num_qa,
"num_turns": args.turns,