From 40e53d65cbb8b609a6ff8e977d2318044d0f0ee0 Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Thu, 13 Jun 2024 16:37:12 +0800
Subject: [PATCH] Add disk cache for loading ShareGPT dataset. (#542)

---
 .../latency_throughput/bench_throughput.py   | 86 ++++++++++++-------
 1 file changed, 53 insertions(+), 33 deletions(-)

diff --git a/benchmark/latency_throughput/bench_throughput.py b/benchmark/latency_throughput/bench_throughput.py
index 286f1fb12..003a6bc80 100644
--- a/benchmark/latency_throughput/bench_throughput.py
+++ b/benchmark/latency_throughput/bench_throughput.py
@@ -19,6 +19,7 @@ On the client side, run:
 import argparse
 import asyncio
 import json
+import os
 import random
 import time
 from typing import AsyncGenerator, List, Tuple
@@ -37,43 +38,62 @@ def sample_requests(
     num_requests: int,
     tokenizer: AutoTokenizer,
 ) -> List[Tuple[str, int, int]]:
-    # Load the dataset.
-    with open(dataset_path) as f:
-        dataset = json.load(f)
-    # Filter out the conversations with less than 2 turns.
-    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-    # Only keep the first two turns of each conversation.
-    dataset = [
-        (data["conversations"][0]["value"], data["conversations"][1]["value"])
-        for data in dataset
-    ]
-    # Tokenize the prompts and completions.
-    prompts = [prompt for prompt, _ in dataset]
-    prompt_token_ids = tokenizer(prompts).input_ids
-    completions = [completion for _, completion in dataset]
-    completion_token_ids = tokenizer(completions).input_ids
-    tokenized_dataset = []
-    for i in range(len(dataset)):
-        output_len = len(completion_token_ids[i])
-        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+    def load_dataset():
+        with open(dataset_path, encoding="utf-8") as f:
+            dataset = json.load(f)
+        # Filter out the conversations with less than 2 turns.
+        dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+        # Only keep the first two turns of each conversation.
+        dataset = [
+            (data["conversations"][0]["value"], data["conversations"][1]["value"])
+            for data in dataset
+        ]
 
-    # Filter out too long sequences.
-    filtered_dataset: List[Tuple[str, int, int]] = []
-    for prompt, prompt_token_ids, output_len in tokenized_dataset:
-        prompt_len = len(prompt_token_ids)
-        if prompt_len < 4 or output_len < 4:
-            # Prune too short sequences.
-            # This is because TGI causes errors when the input or output length
-            # is too short.
-            continue
-        if prompt_len > 1024 or prompt_len + output_len > 2048:
-            # Prune too long sequences.
-            continue
-        filtered_dataset.append((prompt, prompt_len, output_len))
+        # Tokenize the prompts and completions.
+        prompts = [prompt for prompt, _ in dataset]
+        prompt_token_ids = tokenizer(prompts).input_ids
+        completions = [completion for _, completion in dataset]
+        completion_token_ids = tokenizer(completions).input_ids
+        tokenized_dataset = []
+        for i in range(len(dataset)):
+            output_len = len(completion_token_ids[i])
+            tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+
+        # Filter out too long sequences.
+        filtered_dataset: List[Tuple[str, int, int]] = []
+        for prompt, prompt_token_ids, output_len in tokenized_dataset:
+            prompt_len = len(prompt_token_ids)
+            if prompt_len < 4 or output_len < 4:
+                # Prune too short sequences.
+                # This is because TGI causes errors when the input or output length
+                # is too short.
+                continue
+            if prompt_len > 1024 or prompt_len + output_len > 2048:
+                # Prune too long sequences.
+                continue
+            filtered_dataset.append((prompt, prompt_len, output_len))
+
+        return filtered_dataset
+
+    try:
+        from diskcache import Cache
+
+        home_dir = os.path.expanduser("~")
+        cache = Cache(f"{home_dir}/.cache/sglang")
+        with Cache(cache.directory) as reference:
+            reference_key = f"{dataset_path}_{tokenizer.name_or_path}"
+            if reference_key in reference:
+                print("Reading dataset from cache...")
+                dataset = reference[reference_key]
+            else:
+                dataset = load_dataset()
+                reference[reference_key] = dataset
+    except ImportError:
+        dataset = load_dataset()
 
     # Sample the requests.
-    sampled_requests = random.sample(filtered_dataset, num_requests)
+    sampled_requests = random.sample(dataset, num_requests)
     return sampled_requests
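Note for reviewers: the change amounts to a read-through disk cache around the existing load/tokenize/filter code, keyed on the dataset path plus the tokenizer name, with a fallback to the uncached path when diskcache is not installed. Below is a minimal standalone sketch of the same pattern; cached_load and load_fn are hypothetical names introduced here for illustration only, while the cache directory and key format mirror the diff.

import os

def cached_load(cache_key, load_fn):
    # Return load_fn()'s result, memoized on disk when diskcache is available.
    try:
        from diskcache import Cache  # optional dependency, as in the patch
    except ImportError:
        return load_fn()  # diskcache not installed: always recompute
    cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "sglang")
    with Cache(cache_dir) as cache:
        if cache_key in cache:
            print("Reading dataset from cache...")
            return cache[cache_key]
        result = load_fn()
        cache[cache_key] = result  # persist for subsequent benchmark runs
        return result

# Usage mirroring the patch's key construction:
# requests = cached_load(f"{dataset_path}_{tokenizer.name_or_path}", load_dataset)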