diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py
index c872ec2b6..9991a40ab 100644
--- a/python/sglang/bench_serving.py
+++ b/python/sglang/bench_serving.py
@@ -192,6 +192,36 @@ class BenchmarkMetrics:
     p99_itl_ms: float
 
 
+default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
+
+
+def download_sharegpt_dataset(path):
+    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
+
+    print(f"Downloading dataset from {url}")
+    try:
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+
+        total_size = int(response.headers.get("content-length", 0))
+        block_size = 8192
+
+        with open(path, "wb") as f, tqdm(
+            desc="Downloading",
+            total=total_size,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as progress_bar:
+            for data in response.iter_content(block_size):
+                size = f.write(data)
+                progress_bar.update(size)
+
+        print(f"Dataset downloaded and saved to {path}")
+    except requests.RequestException as e:
+        raise Exception(f"Failed to download dataset: {e}")
+
+
 def sample_sharegpt_requests(
     dataset_path: str,
     num_requests: int,
@@ -201,36 +231,13 @@ def sample_sharegpt_requests(
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
 
-    default_dataset_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
-    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
-
-    if not os.path.isfile(dataset_path) and not os.path.isfile(default_dataset_path):
-        print(f"Downloading dataset from {url}")
-        try:
-            response = requests.get(url, stream=True)
-            response.raise_for_status()
-
-            total_size = int(response.headers.get("content-length", 0))
-            block_size = 8192
-
-            with open(default_dataset_path, "wb") as f, tqdm(
-                desc="Downloading",
-                total=total_size,
-                unit="iB",
-                unit_scale=True,
-                unit_divisor=1024,
-            ) as progress_bar:
-                for data in response.iter_content(block_size):
-                    size = f.write(data)
-                    progress_bar.update(size)
-
-            print(f"Dataset downloaded and saved to {default_dataset_path}")
-            dataset_path = default_dataset_path
-        except requests.RequestException as e:
-            raise Exception(f"Failed to download dataset: {e}")
+    # Download sharegpt if necessary
+    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
+        download_sharegpt_dataset(default_sharegpt_path)
+        dataset_path = default_sharegpt_path
     else:
         dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_dataset_path
+            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
         )
 
     # Load the dataset.
@@ -279,6 +286,7 @@ def sample_random_requests(
     num_prompts: int,
     range_ratio: float,
     tokenizer: PreTrainedTokenizerBase,
+    dataset_path: str,
 ) -> List[Tuple[str, int, int]]:
 
     input_lens = np.random.randint(
@@ -291,13 +299,62 @@ def sample_random_requests(
         output_len + 1,
         size=num_prompts,
     )
-    offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
-    input_requests = []
-    for i in range(num_prompts):
-        prompt = tokenizer.decode(
-            [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]
-        )
-        input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+
+    if True:
+        # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
+
+        # Download sharegpt if necessary
+        if not os.path.isfile(dataset_path) and not os.path.isfile(
+            default_sharegpt_path
+        ):
+            download_sharegpt_dataset(default_sharegpt_path)
+            dataset_path = default_sharegpt_path
+        else:
+            dataset_path = (
+                dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
+            )
+
+        # Load the dataset.
+        with open(dataset_path) as f:
+            dataset = json.load(f)
+        # Filter out the conversations with less than 2 turns.
+        dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+        # Only keep the first two turns of each conversation.
+        dataset = [
+            (data["conversations"][0]["value"], data["conversations"][1]["value"])
+            for data in dataset
+        ]
+
+        # Shuffle the dataset.
+        random.shuffle(dataset)
+
+        # Repeat or truncate each sampled prompt so it matches input_lens[i].
+        input_requests: List[Tuple[str, int, int]] = []
+        for i in range(num_prompts):
+            # Tokenize the prompt.
+            prompt = dataset[i][0]
+            prompt_token_ids = tokenizer(prompt).input_ids
+            prompt_len = len(prompt_token_ids)
+
+            if prompt_len > input_lens[i]:
+                input_ids = prompt_token_ids[: input_lens[i]]
+            else:
+                ratio = (input_lens[i] + prompt_len - 1) // prompt_len
+                input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
+            prompt = tokenizer.decode(input_ids)
+            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+    else:
+        # Sample token ids from random integers. This can cause some NaN issues.
+        offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
+        input_requests = []
+        for i in range(num_prompts):
+            prompt = tokenizer.decode(
+                [
+                    (offsets[i] + i + j) % tokenizer.vocab_size
+                    for j in range(input_lens[i])
+                ]
+            )
+            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
 
     print(f"#Input tokens: {np.sum(input_lens)}")
     print(f"#Output tokens: {np.sum(output_lens)}")
@@ -575,6 +632,7 @@ def fire(args: argparse.Namespace):
             num_prompts=args.num_prompts,
             range_ratio=args.random_range_ratio,
             tokenizer=tokenizer,
+            dataset_path=args.dataset_path,
         )
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
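The core of the new sampling path is the repeat-or-truncate step: each ShareGPT prompt is tokenized, tiled enough times (ceiling division) to cover the randomly drawn target length, then cut to exactly that length. A minimal, self-contained sketch of just that step follows; the helper name fit_to_length and the example values are illustrative, not part of the patch, and like the patch it assumes a non-empty token list.

from typing import List


def fit_to_length(prompt_token_ids: List[int], target_len: int) -> List[int]:
    # Truncate prompts that are already long enough.
    if len(prompt_token_ids) > target_len:
        return prompt_token_ids[:target_len]
    # Otherwise tile the prompt: ceiling division gives the number of
    # copies needed to reach target_len, then cut off the surplus.
    prompt_len = len(prompt_token_ids)
    ratio = (target_len + prompt_len - 1) // prompt_len
    return (prompt_token_ids * ratio)[:target_len]


assert fit_to_length([1, 2, 3], 7) == [1, 2, 3, 1, 2, 3, 1]
assert fit_to_length([1, 2, 3, 4], 2) == [1, 2]

Using real text this way keeps token statistics natural, which is what the dropped random-integer path could not guarantee (hence its "NaN issues" comment).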
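For completeness, here is a sketch of calling the updated sampler directly. The leading input_len and output_len parameters are assumed from the function body (they are not visible in the hunks), and the gpt2 tokenizer is an illustrative stand-in. Passing a path that is not a file falls through to default_sharegpt_path, downloading it on first use.

from transformers import AutoTokenizer

# Illustrative tokenizer; any HuggingFace tokenizer should work.
tokenizer = AutoTokenizer.from_pretrained("gpt2")

input_requests = sample_random_requests(
    input_len=1024,   # assumed parameter, not shown in the hunks
    output_len=128,   # assumed parameter, not shown in the hunks
    num_prompts=16,
    range_ratio=0.5,
    tokenizer=tokenizer,
    dataset_path="",  # not a file, so the ShareGPT default is used/downloaded
)
for prompt, in_len, out_len in input_requests[:2]:
    print(in_len, out_len, prompt[:60])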