[Fix] Fix select by ensuring each request has at least one token (#1318)

This commit is contained in:
Lianmin Zheng
2024-09-03 06:31:45 -07:00
committed by GitHub
parent 12cb115d38
commit 1e495e0847
4 changed files with 120 additions and 3 deletions

View File

@@ -4,6 +4,7 @@ import base64
import importlib
import json
import logging
import os
import signal
import sys
import traceback
@@ -15,6 +16,7 @@ from typing import Union
import numpy as np
import requests
from tqdm import tqdm
logger = logging.getLogger(__name__)
@@ -260,3 +262,40 @@ class LazyImport:
def __call__(self, *args, **kwargs):
    """Call the object returned by ``self._load()`` with the given arguments.

    All positional and keyword arguments are forwarded unchanged, and the
    result of that call is returned to the caller.
    """
    return self._load()(*args, **kwargs)
def fetch_and_cache_jsonl(url, cache_file="cached_data.jsonl"):
    """Read and cache a jsonl file from a url.

    If ``cache_file`` already exists it is parsed directly; otherwise the
    content at ``url`` is downloaded (with a progress bar) into ``cache_file``
    first and then parsed.

    Args:
        url: Location of the JSON Lines file to download on a cache miss.
        cache_file: Path of the local cache file.

    Returns:
        A list with one decoded JSON value per non-blank line of the file.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # Check if the cache file already exists
    if os.path.exists(cache_file):
        print("Loading data from cache...")
    else:
        print("Downloading data from URL...")

        # Download into a temporary sibling file and move it into place only
        # after the transfer completed, so an interrupted download can never
        # be mistaken for a valid cache on the next call.
        tmp_file = f"{cache_file}.tmp"
        try:
            # Stream the response to show the progress bar; the context
            # manager releases the connection even if the download fails.
            with requests.get(url, stream=True) as response:
                response.raise_for_status()  # Check for request errors

                # Total size of the file in bytes (0 if the header is absent)
                total_size = int(response.headers.get("content-length", 0))
                chunk_size = 1024  # Download in chunks of 1KB

                # Use tqdm to display the progress bar
                with open(tmp_file, "wb") as f, tqdm(
                    desc=cache_file,
                    total=total_size,
                    unit="B",
                    unit_scale=True,
                    unit_divisor=1024,
                ) as bar:
                    for chunk in response.iter_content(chunk_size=chunk_size):
                        f.write(chunk)
                        bar.update(len(chunk))

            # Atomic rename: the cache file appears only when it is complete.
            os.replace(tmp_file, cache_file)
        finally:
            # Remove the partial download left behind by a failed transfer.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

    # Convert the data to a list of dictionaries. Blank lines (e.g. a trailing
    # newline at the end of the file) are skipped instead of crashing the
    # JSON decoder.
    with open(cache_file, "r", encoding="utf-8") as f:
        data = [json.loads(line) for line in f if line.strip()]

    return data