Add disk cache for loading ShareGPT dataset. (#542)

This commit is contained in:
Liangsheng Yin
2024-06-13 16:37:12 +08:00
committed by GitHub
parent fb9296f0ed
commit 40e53d65cb

View File

@@ -19,6 +19,7 @@ On the client side, run:
import argparse import argparse
import asyncio import asyncio
import json import json
import os
import random import random
import time import time
from typing import AsyncGenerator, List, Tuple from typing import AsyncGenerator, List, Tuple
@@ -37,8 +38,9 @@ def sample_requests(
num_requests: int, num_requests: int,
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
) -> List[Tuple[str, int, int]]: ) -> List[Tuple[str, int, int]]:
# Load the dataset.
with open(dataset_path) as f: def load_dataset():
with open(dataset_path, encoding="utf-8") as f:
dataset = json.load(f) dataset = json.load(f)
# Filter out the conversations with less than 2 turns. # Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2] dataset = [data for data in dataset if len(data["conversations"]) >= 2]
@@ -72,8 +74,26 @@ def sample_requests(
continue continue
filtered_dataset.append((prompt, prompt_len, output_len)) filtered_dataset.append((prompt, prompt_len, output_len))
return filtered_dataset
try:
from diskcache import Cache
home_dir = os.path.expanduser("~")
cache = Cache(f"{home_dir}/.cache/sglang")
with Cache(cache.directory) as reference:
reference_key = f"{dataset_path}_{tokenizer.name_or_path}"
if reference_key in reference:
print("Reading dataset from cache...")
dataset = reference[reference_key]
else:
dataset = load_dataset()
reference[reference_key] = dataset
except ImportError:
dataset = load_dataset()
# Sample the requests. # Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests) sampled_requests = random.sample(dataset, num_requests)
return sampled_requests return sampled_requests