Fix random dataset (#671)
@@ -192,6 +192,36 @@ class BenchmarkMetrics:
     p99_itl_ms: float
 
 
+default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
+
+
+def download_sharegpt_dataset(path):
+    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
+
+    print(f"Downloading dataset from {url}")
+    try:
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+
+        total_size = int(response.headers.get("content-length", 0))
+        block_size = 8192
+
+        with open(path, "wb") as f, tqdm(
+            desc="Downloading",
+            total=total_size,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as progress_bar:
+            for data in response.iter_content(block_size):
+                size = f.write(data)
+                progress_bar.update(size)
+
+        print(f"Dataset downloaded and saved to {path}")
+    except requests.RequestException as e:
+        raise Exception(f"Failed to download dataset: {e}")
+
+
 def sample_sharegpt_requests(
     dataset_path: str,
     num_requests: int,
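The new `download_sharegpt_dataset` helper relies on `os`, `requests`, and `tqdm` already being imported at the top of the script; those imports sit outside this hunk. A minimal standalone sketch, assuming those imports and mirroring the guarded call the later hunks make (the exact tqdm import path is an assumption):

# Sketch only: imports assumed to exist at module top, not added by this hunk.
import os

import requests
from tqdm import tqdm

# Fetch the default ShareGPT file only if it is not already on disk.
if not os.path.isfile(default_sharegpt_path):
    download_sharegpt_dataset(default_sharegpt_path)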
@@ -201,36 +231,13 @@ def sample_sharegpt_requests(
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
 
-    default_dataset_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
-    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
-
-    if not os.path.isfile(dataset_path) and not os.path.isfile(default_dataset_path):
-        print(f"Downloading dataset from {url}")
-        try:
-            response = requests.get(url, stream=True)
-            response.raise_for_status()
-
-            total_size = int(response.headers.get("content-length", 0))
-            block_size = 8192
-
-            with open(default_dataset_path, "wb") as f, tqdm(
-                desc="Downloading",
-                total=total_size,
-                unit="iB",
-                unit_scale=True,
-                unit_divisor=1024,
-            ) as progress_bar:
-                for data in response.iter_content(block_size):
-                    size = f.write(data)
-                    progress_bar.update(size)
-
-            print(f"Dataset downloaded and saved to {default_dataset_path}")
-            dataset_path = default_dataset_path
-        except requests.RequestException as e:
-            raise Exception(f"Failed to download dataset: {e}")
+    # Download sharegpt if necessary
+    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
+        download_sharegpt_dataset(default_sharegpt_path)
+        dataset_path = default_sharegpt_path
     else:
         dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_dataset_path
+            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
         )
 
     # Load the dataset.
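The refactor above reduces the inline download code to a single path-resolution rule: use `dataset_path` if it exists on disk, otherwise fall back to `default_sharegpt_path`, downloading that file first when neither is present. A minimal sketch of the same rule as a standalone function (the name `resolve_sharegpt_path` is illustrative; the patch keeps this logic inline):

def resolve_sharegpt_path(dataset_path: str) -> str:
    # Hypothetical helper, not part of the patch: same fallback rule as above.
    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
        download_sharegpt_dataset(default_sharegpt_path)
        return default_sharegpt_path
    return dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path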
@@ -279,6 +286,7 @@ def sample_random_requests(
     num_prompts: int,
     range_ratio: float,
     tokenizer: PreTrainedTokenizerBase,
+    dataset_path: str,
 ) -> List[Tuple[str, int, int]]:
 
     input_lens = np.random.randint(
@@ -291,13 +299,62 @@ def sample_random_requests(
         output_len + 1,
         size=num_prompts,
     )
-    offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
-    input_requests = []
-    for i in range(num_prompts):
-        prompt = tokenizer.decode(
-            [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]
-        )
-        input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+
+    if True:
+        # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
+
+        # Download sharegpt if necessary
+        if not os.path.isfile(dataset_path) and not os.path.isfile(
+            default_sharegpt_path
+        ):
+            download_sharegpt_dataset(default_sharegpt_path)
+            dataset_path = default_sharegpt_path
+        else:
+            dataset_path = (
+                dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
+            )
+
+        # Load the dataset.
+        with open(dataset_path) as f:
+            dataset = json.load(f)
+        # Filter out the conversations with less than 2 turns.
+        dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+        # Only keep the first two turns of each conversation.
+        dataset = [
+            (data["conversations"][0]["value"], data["conversations"][1]["value"])
+            for data in dataset
+        ]
+
+        # Shuffle the dataset.
+        random.shuffle(dataset)
+
+        # Filter out sequences that are too long or too short
+        input_requests: List[Tuple[str, int, int]] = []
+        for i in range(num_prompts):
+            # Tokenize the prompts and completions.
+            prompt = dataset[i][0]
+            prompt_token_ids = tokenizer(prompt).input_ids
+            prompt_len = len(prompt_token_ids)
+
+            if prompt_len <= input_lens[i]:
+                input_ids = prompt_token_ids[: input_lens[i]]
+            else:
+                ratio = (input_lens[i] + prompt_len - 1) // prompt_len
+                input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
+            prompt = tokenizer.decode(input_ids)
+            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+    else:
+        # Sample token ids from random integers. This can cause some NaN issues.
+        offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
+        input_requests = []
+        for i in range(num_prompts):
+            prompt = tokenizer.decode(
+                [
+                    (offsets[i] + i + j) % tokenizer.vocab_size
+                    for j in range(input_lens[i])
+                ]
+            )
+            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
 
     print(f"#Input tokens: {np.sum(input_lens)}")
     print(f"#Output tokens: {np.sum(output_lens)}")
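The tiling branch above uses a ceiling division to decide how many copies of the prompt's token ids are needed before truncating to the target length. A worked toy example with plain lists (the token id values are made up for illustration):

# Toy illustration of the repeat/truncate arithmetic used above.
prompt_token_ids = [101, 7592, 2088]  # prompt_len = 3
target_len = 8                        # plays the role of input_lens[i]

ratio = (target_len + len(prompt_token_ids) - 1) // len(prompt_token_ids)  # ceil(8 / 3) = 3
input_ids = (prompt_token_ids * ratio)[:target_len]
assert len(input_ids) == target_len   # [101, 7592, 2088, 101, 7592, 2088, 101, 7592]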
@@ -575,6 +632,7 @@ def fire(args: argparse.Namespace):
             num_prompts=args.num_prompts,
             range_ratio=args.random_range_ratio,
             tokenizer=tokenizer,
+            dataset_path=args.dataset_path,
         )
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
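With `dataset_path` now threaded through `fire`, the random sampler can also be exercised directly. A rough sketch of a direct call, assuming the parameters not visible in these hunks (`input_len`, `output_len`) keep those names; only the keywords shown in the hunks are certain:

# Illustrative call outside the benchmark harness.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
input_requests = sample_random_requests(
    input_len=512,    # assumed parameter name, not shown in this diff
    output_len=64,    # assumed parameter name, not shown in this diff
    num_prompts=4,
    range_ratio=0.5,
    tokenizer=tokenizer,
    dataset_path="",  # missing path falls back to default_sharegpt_path (downloaded if absent)
)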