Add disk cache for loading ShareGPT dataset. (#542)

This commit is contained in:
Liangsheng Yin
2024-06-13 16:37:12 +08:00
committed by GitHub
parent fb9296f0ed
commit 40e53d65cb

View File

@@ -19,6 +19,7 @@ On the client side, run:
import argparse import argparse
import asyncio import asyncio
import json import json
import os
import random import random
import time import time
from typing import AsyncGenerator, List, Tuple from typing import AsyncGenerator, List, Tuple
@@ -37,43 +38,62 @@ def sample_requests(
num_requests: int, num_requests: int,
tokenizer: AutoTokenizer, tokenizer: AutoTokenizer,
) -> List[Tuple[str, int, int]]: ) -> List[Tuple[str, int, int]]:
# Load the dataset.
with open(dataset_path) as f:
dataset = json.load(f)
# Filter out the conversations with less than 2 turns.
dataset = [data for data in dataset if len(data["conversations"]) >= 2]
# Only keep the first two turns of each conversation.
dataset = [
(data["conversations"][0]["value"], data["conversations"][1]["value"])
for data in dataset
]
# Tokenize the prompts and completions. def load_dataset():
prompts = [prompt for prompt, _ in dataset] with open(dataset_path, encoding="utf-8") as f:
prompt_token_ids = tokenizer(prompts).input_ids dataset = json.load(f)
completions = [completion for _, completion in dataset] # Filter out the conversations with less than 2 turns.
completion_token_ids = tokenizer(completions).input_ids dataset = [data for data in dataset if len(data["conversations"]) >= 2]
tokenized_dataset = [] # Only keep the first two turns of each conversation.
for i in range(len(dataset)): dataset = [
output_len = len(completion_token_ids[i]) (data["conversations"][0]["value"], data["conversations"][1]["value"])
tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) for data in dataset
]
# Filter out too long sequences. # Tokenize the prompts and completions.
filtered_dataset: List[Tuple[str, int, int]] = [] prompts = [prompt for prompt, _ in dataset]
for prompt, prompt_token_ids, output_len in tokenized_dataset: prompt_token_ids = tokenizer(prompts).input_ids
prompt_len = len(prompt_token_ids) completions = [completion for _, completion in dataset]
if prompt_len < 4 or output_len < 4: completion_token_ids = tokenizer(completions).input_ids
# Prune too short sequences. tokenized_dataset = []
# This is because TGI causes errors when the input or output length for i in range(len(dataset)):
# is too short. output_len = len(completion_token_ids[i])
continue tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences. # Filter out too long sequences.
continue filtered_dataset: List[Tuple[str, int, int]] = []
filtered_dataset.append((prompt, prompt_len, output_len)) for prompt, prompt_token_ids, output_len in tokenized_dataset:
prompt_len = len(prompt_token_ids)
if prompt_len < 4 or output_len < 4:
# Prune too short sequences.
# This is because TGI causes errors when the input or output length
# is too short.
continue
if prompt_len > 1024 or prompt_len + output_len > 2048:
# Prune too long sequences.
continue
filtered_dataset.append((prompt, prompt_len, output_len))
return filtered_dataset
try:
from diskcache import Cache
home_dir = os.path.expanduser("~")
cache = Cache(f"{home_dir}/.cache/sglang")
with Cache(cache.directory) as reference:
reference_key = f"{dataset_path}_{tokenizer.name_or_path}"
if reference_key in reference:
print("Reading dataset from cache...")
dataset = reference[reference_key]
else:
dataset = load_dataset()
reference[reference_key] = dataset
except ImportError:
dataset = load_dataset()
# Sample the requests. # Sample the requests.
sampled_requests = random.sample(filtered_dataset, num_requests) sampled_requests = random.sample(dataset, num_requests)
return sampled_requests return sampled_requests