Cleanup readme, llava examples, usage examples and nccl init (#1194)
This commit is contained in:
45
examples/runtime/async_io_api.py
Normal file
45
examples/runtime/async_io_api.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Usage:
|
||||
python3 async_io.py
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
|
||||
from sglang import Runtime
|
||||
|
||||
|
||||
async def generate(
|
||||
engine,
|
||||
prompt,
|
||||
sampling_params,
|
||||
):
|
||||
tokenizer = engine.get_tokenizer()
|
||||
|
||||
messages = [
|
||||
{
|
||||
"role": "system",
|
||||
"content": "You will be given question answer tasks.",
|
||||
},
|
||||
{"role": "user", "content": prompt},
|
||||
]
|
||||
|
||||
prompt = tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
|
||||
stream = engine.add_request(prompt, sampling_params)
|
||||
|
||||
async for output in stream:
|
||||
print(output, end="", flush=True)
|
||||
print()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
runtime = Runtime(model_path="meta-llama/Llama-2-7b-chat-hf")
|
||||
print("--- runtime ready ---\n")
|
||||
|
||||
prompt = "Who is Alan Turing?"
|
||||
sampling_params = {"max_new_tokens": 128}
|
||||
asyncio.run(generate(runtime, prompt, sampling_params))
|
||||
|
||||
runtime.shutdown()
|
||||
111
examples/runtime/llava_onevision/http_llama3_llava_test.py
Normal file
111
examples/runtime/llava_onevision/http_llama3_llava_test.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
Usage:
|
||||
# Installing latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
|
||||
# Installing latest sglang.
|
||||
|
||||
# Endpoint Service CLI:
|
||||
python -m sglang.launch_server --model-path lmms-lab/llama3-llava-next-8b --port=30000
|
||||
|
||||
python3 http_llama3_llava_test.py
|
||||
|
||||
Output:
|
||||
"Friends posing for a fun photo with a life-sized teddy bear, creating a playful and memorable moment."
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from llava.conversation import conv_llava_llama_3
|
||||
|
||||
|
||||
async def send_request(url, data, delay=0):
|
||||
await asyncio.sleep(delay)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=data) as resp:
|
||||
output = await resp.json()
|
||||
return output
|
||||
|
||||
|
||||
async def test_concurrent(args):
|
||||
url = f"{args.host}:{args.port}"
|
||||
|
||||
prompt = "<image>\nPlease generate caption towards this image."
|
||||
conv_template = copy.deepcopy(conv_llava_llama_3)
|
||||
conv_template.append_message(role=conv_template.roles[0], message=prompt)
|
||||
conv_template.append_message(role=conv_template.roles[1], message=None)
|
||||
prompt_with_template = conv_template.get_prompt()
|
||||
response = []
|
||||
for i in range(1):
|
||||
response.append(
|
||||
send_request(
|
||||
url + "/generate",
|
||||
{
|
||||
"text": prompt_with_template,
|
||||
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 2,
|
||||
"frequency_penalty": 2,
|
||||
"stop": "<|eot_id|>",
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
rets = await asyncio.gather(*response)
|
||||
for ret in rets:
|
||||
print(ret["text"])
|
||||
|
||||
|
||||
def test_streaming(args):
|
||||
url = f"{args.host}:{args.port}"
|
||||
prompt = "<image>\nPlease generate caption towards this image."
|
||||
conv_template = copy.deepcopy(conv_llava_llama_3)
|
||||
conv_template.append_message(role=conv_template.roles[0], message=prompt)
|
||||
conv_template.append_message(role=conv_template.roles[1], message=None)
|
||||
prompt_with_template = conv_template.get_prompt()
|
||||
pload = {
|
||||
"text": prompt_with_template,
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 2,
|
||||
"frequency_penalty": 2,
|
||||
"stop": "<|eot_id|>",
|
||||
},
|
||||
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
|
||||
"stream": True,
|
||||
}
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json=pload,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
prev = 0
|
||||
for chunk in response.iter_lines(decode_unicode=False):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk and chunk.startswith("data:"):
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
data = json.loads(chunk[5:].strip("\n"))
|
||||
output = data["text"].strip()
|
||||
print(output[prev:], end="", flush=True)
|
||||
prev = len(output)
|
||||
print("")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
args = parser.parse_args()
|
||||
asyncio.run(test_concurrent(args))
|
||||
test_streaming(args)
|
||||
218
examples/runtime/llava_onevision/http_llava_onevision_test.py
Normal file
218
examples/runtime/llava_onevision/http_llava_onevision_test.py
Normal file
@@ -0,0 +1,218 @@
|
||||
"""
|
||||
Usage:
|
||||
|
||||
python3 -m sglang.launch_server --model-path lmms-lab/llava-onevision-qwen2-72b-ov --port=30000 --tp-size=8 --chat-template=chatml-llava --chunked-prefill-size=16384
|
||||
|
||||
python3 http_llava_onevision_test.py
|
||||
"""
|
||||
|
||||
import base64
|
||||
import io
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import openai
|
||||
import requests
|
||||
from decord import VideoReader, cpu
|
||||
from PIL import Image
|
||||
|
||||
# pip install httpx==0.23.3
|
||||
# pip install decord
|
||||
# pip install protobuf==3.20.0
|
||||
|
||||
|
||||
def download_video(url, cache_dir):
|
||||
file_path = os.path.join(cache_dir, "jobs.mp4")
|
||||
os.makedirs(cache_dir, exist_ok=True)
|
||||
|
||||
response = requests.get(url)
|
||||
response.raise_for_status()
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
print(f"File downloaded and saved to: {file_path}")
|
||||
return file_path
|
||||
|
||||
|
||||
def create_openai_client(base_url):
|
||||
return openai.Client(api_key="EMPTY", base_url=base_url)
|
||||
|
||||
|
||||
def image_stream_request_test(client):
|
||||
print("----------------------Image Stream Request Test----------------------")
|
||||
stream_request = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Please describe this image. Please list the benchmarks and the models.",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=0.7,
|
||||
max_tokens=1024,
|
||||
stream=True,
|
||||
)
|
||||
stream_response = ""
|
||||
|
||||
for chunk in stream_request:
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
content = chunk.choices[0].delta.content
|
||||
stream_response += content
|
||||
sys.stdout.write(content)
|
||||
sys.stdout.flush()
|
||||
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def video_stream_request_test(client, video_path):
|
||||
print("------------------------Video Stream Request Test----------------------")
|
||||
messages = prepare_video_messages(video_path)
|
||||
|
||||
video_request = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_tokens=1024,
|
||||
stream=True,
|
||||
)
|
||||
print("-" * 30)
|
||||
video_response = ""
|
||||
|
||||
for chunk in video_request:
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
content = chunk.choices[0].delta.content
|
||||
video_response += content
|
||||
sys.stdout.write(content)
|
||||
sys.stdout.flush()
|
||||
print("-" * 30)
|
||||
|
||||
|
||||
def image_speed_test(client):
|
||||
print("----------------------Image Speed Test----------------------")
|
||||
start_time = time.time()
|
||||
request = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=[
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {
|
||||
"url": "https://raw.githubusercontent.com/sgl-project/sglang/main/assets/logo.png"
|
||||
},
|
||||
},
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Please describe this image. Please list the benchmarks and the models.",
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
temperature=0,
|
||||
max_tokens=1024,
|
||||
)
|
||||
end_time = time.time()
|
||||
response = request.choices[0].message.content
|
||||
print(response)
|
||||
print("-" * 30)
|
||||
print_speed_test_results(request, start_time, end_time)
|
||||
|
||||
|
||||
def video_speed_test(client, video_path):
|
||||
print("------------------------Video Speed Test------------------------")
|
||||
messages = prepare_video_messages(video_path)
|
||||
|
||||
start_time = time.time()
|
||||
video_request = client.chat.completions.create(
|
||||
model="default",
|
||||
messages=messages,
|
||||
temperature=0,
|
||||
max_tokens=1024,
|
||||
)
|
||||
end_time = time.time()
|
||||
video_response = video_request.choices[0].message.content
|
||||
print(video_response)
|
||||
print("-" * 30)
|
||||
print_speed_test_results(video_request, start_time, end_time)
|
||||
|
||||
|
||||
def prepare_video_messages(video_path):
|
||||
max_frames_num = 32
|
||||
vr = VideoReader(video_path, ctx=cpu(0))
|
||||
total_frame_num = len(vr)
|
||||
uniform_sampled_frames = np.linspace(
|
||||
0, total_frame_num - 1, max_frames_num, dtype=int
|
||||
)
|
||||
frame_idx = uniform_sampled_frames.tolist()
|
||||
frames = vr.get_batch(frame_idx).asnumpy()
|
||||
|
||||
base64_frames = []
|
||||
for frame in frames:
|
||||
pil_img = Image.fromarray(frame)
|
||||
buff = io.BytesIO()
|
||||
pil_img.save(buff, format="JPEG")
|
||||
base64_str = base64.b64encode(buff.getvalue()).decode("utf-8")
|
||||
base64_frames.append(base64_str)
|
||||
|
||||
messages = [{"role": "user", "content": []}]
|
||||
frame_format = {
|
||||
"type": "image_url",
|
||||
"image_url": {"url": "data:image/jpeg;base64,{}"},
|
||||
}
|
||||
|
||||
for base64_frame in base64_frames:
|
||||
frame_format["image_url"]["url"] = "data:image/jpeg;base64,{}".format(
|
||||
base64_frame
|
||||
)
|
||||
messages[0]["content"].append(frame_format.copy())
|
||||
|
||||
prompt = {"type": "text", "text": "Please describe the video in detail."}
|
||||
messages[0]["content"].append(prompt)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def print_speed_test_results(request, start_time, end_time):
|
||||
total_tokens = request.usage.total_tokens
|
||||
completion_tokens = request.usage.completion_tokens
|
||||
prompt_tokens = request.usage.prompt_tokens
|
||||
|
||||
print(f"Total tokens: {total_tokens}")
|
||||
print(f"Completion tokens: {completion_tokens}")
|
||||
print(f"Prompt tokens: {prompt_tokens}")
|
||||
print(f"Time taken: {end_time - start_time} seconds")
|
||||
print(f"Token per second: {total_tokens / (end_time - start_time)}")
|
||||
print(f"Completion token per second: {completion_tokens / (end_time - start_time)}")
|
||||
print(f"Prompt token per second: {prompt_tokens / (end_time - start_time)}")
|
||||
|
||||
|
||||
def main():
|
||||
url = "https://raw.githubusercontent.com/EvolvingLMMs-Lab/sglang/dev/onevision_local/assets/jobs.mp4"
|
||||
cache_dir = os.path.expanduser("~/.cache")
|
||||
video_path = download_video(url, cache_dir)
|
||||
|
||||
client = create_openai_client("http://127.0.0.1:30000/v1")
|
||||
|
||||
image_stream_request_test(client)
|
||||
video_stream_request_test(client, video_path)
|
||||
image_speed_test(client)
|
||||
video_speed_test(client, video_path)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
111
examples/runtime/llava_onevision/http_qwen_llava_test.py
Normal file
111
examples/runtime/llava_onevision/http_qwen_llava_test.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
Usage:
|
||||
# Installing latest llava-next: pip install git+https://github.com/LLaVA-VL/LLaVA-NeXT.git
|
||||
# Installing latest sglang.
|
||||
|
||||
# Endpoint Service CLI:
|
||||
python -m sglang.launch_server --model-path lmms-lab/llava-next-72b --port=30000 --tp-size=8
|
||||
|
||||
python3 http_qwen_llava_test.py
|
||||
|
||||
Output:
|
||||
"Two children pose with a large teddy bear, one holding a smaller stuffed bear, in a room with an American flag and potted plants."
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import copy
|
||||
import json
|
||||
|
||||
import aiohttp
|
||||
import requests
|
||||
from llava.conversation import conv_qwen
|
||||
|
||||
|
||||
async def send_request(url, data, delay=0):
|
||||
await asyncio.sleep(delay)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(url, json=data) as resp:
|
||||
output = await resp.json()
|
||||
return output
|
||||
|
||||
|
||||
async def test_concurrent(args):
|
||||
url = f"{args.host}:{args.port}"
|
||||
|
||||
prompt = "<image>\nPlease generate caption towards this image."
|
||||
conv_template = copy.deepcopy(conv_qwen)
|
||||
conv_template.append_message(role=conv_template.roles[0], message=prompt)
|
||||
conv_template.append_message(role=conv_template.roles[1], message=None)
|
||||
prompt_with_template = conv_template.get_prompt()
|
||||
response = []
|
||||
for i in range(1):
|
||||
response.append(
|
||||
send_request(
|
||||
url + "/generate",
|
||||
{
|
||||
"text": prompt_with_template,
|
||||
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 2,
|
||||
"frequency_penalty": 2,
|
||||
"stop": "<|im_end|>",
|
||||
},
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
rets = await asyncio.gather(*response)
|
||||
for ret in rets:
|
||||
print(ret["text"])
|
||||
|
||||
|
||||
def test_streaming(args):
|
||||
url = f"{args.host}:{args.port}"
|
||||
prompt = "<image>\nPlease generate caption towards this image."
|
||||
conv_template = copy.deepcopy(conv_qwen)
|
||||
conv_template.append_message(role=conv_template.roles[0], message=prompt)
|
||||
conv_template.append_message(role=conv_template.roles[1], message=None)
|
||||
prompt_with_template = conv_template.get_prompt()
|
||||
pload = {
|
||||
"text": prompt_with_template,
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 1024,
|
||||
"temperature": 0,
|
||||
"top_p": 1.0,
|
||||
"presence_penalty": 2,
|
||||
"frequency_penalty": 2,
|
||||
"stop": "<|im_end|>",
|
||||
},
|
||||
"image_data": "https://farm4.staticflickr.com/3175/2653711032_804ff86d81_z.jpg",
|
||||
"stream": True,
|
||||
}
|
||||
response = requests.post(
|
||||
url + "/generate",
|
||||
json=pload,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
prev = 0
|
||||
for chunk in response.iter_lines(decode_unicode=False):
|
||||
chunk = chunk.decode("utf-8")
|
||||
if chunk and chunk.startswith("data:"):
|
||||
if chunk == "data: [DONE]":
|
||||
break
|
||||
data = json.loads(chunk[5:].strip("\n"))
|
||||
output = data["text"].strip()
|
||||
print(output[prev:], end="", flush=True)
|
||||
prev = len(output)
|
||||
print("")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--host", type=str, default="http://127.0.0.1")
|
||||
parser.add_argument("--port", type=int, default=30000)
|
||||
args = parser.parse_args()
|
||||
asyncio.run(test_concurrent(args))
|
||||
test_streaming(args)
|
||||
96
examples/runtime/openai_batch_chat.py
Normal file
96
examples/runtime/openai_batch_chat.py
Normal file
@@ -0,0 +1,96 @@
|
||||
"""
|
||||
Usage:
|
||||
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||
python openai_batch_chat.py
|
||||
Note: Before running this script,
|
||||
you should create the input.jsonl file with the following content:
|
||||
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world! List 3 NBA players and tell a story"}],"max_tokens": 300}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-3.5-turbo-0125", "messages": [{"role": "system", "content": "You are an assistant. "},{"role": "user", "content": "Hello world! List three capital and tell a story"}],"max_tokens": 500}}
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
class OpenAIBatchProcessor:
|
||||
def __init__(self, api_key):
|
||||
# client = OpenAI(api_key=api_key)
|
||||
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
||||
|
||||
self.client = client
|
||||
|
||||
def process_batch(self, input_file_path, endpoint, completion_window):
|
||||
|
||||
# Upload the input file
|
||||
with open(input_file_path, "rb") as file:
|
||||
uploaded_file = self.client.files.create(file=file, purpose="batch")
|
||||
|
||||
# Create the batch job
|
||||
batch_job = self.client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint=endpoint,
|
||||
completion_window=completion_window,
|
||||
)
|
||||
|
||||
# Monitor the batch job status
|
||||
while batch_job.status not in ["completed", "failed", "cancelled"]:
|
||||
time.sleep(3) # Wait for 3 seconds before checking the status again
|
||||
print(
|
||||
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
|
||||
)
|
||||
batch_job = self.client.batches.retrieve(batch_job.id)
|
||||
|
||||
# Check the batch job status and errors
|
||||
if batch_job.status == "failed":
|
||||
print(f"Batch job failed with status: {batch_job.status}")
|
||||
print(f"Batch job errors: {batch_job.errors}")
|
||||
return None
|
||||
|
||||
# If the batch job is completed, process the results
|
||||
if batch_job.status == "completed":
|
||||
|
||||
# print result of batch job
|
||||
print("batch", batch_job.request_counts)
|
||||
|
||||
result_file_id = batch_job.output_file_id
|
||||
# Retrieve the file content from the server
|
||||
file_response = self.client.files.content(result_file_id)
|
||||
result_content = file_response.read() # Read the content of the file
|
||||
|
||||
# Save the content to a local file
|
||||
result_file_name = "batch_job_chat_results.jsonl"
|
||||
with open(result_file_name, "wb") as file:
|
||||
file.write(result_content) # Write the binary content to the file
|
||||
# Load data from the saved JSONL file
|
||||
results = []
|
||||
with open(result_file_name, "r", encoding="utf-8") as file:
|
||||
for line in file:
|
||||
json_object = json.loads(
|
||||
line.strip()
|
||||
) # Parse each line as a JSON object
|
||||
results.append(json_object)
|
||||
|
||||
return results
|
||||
else:
|
||||
print(f"Batch job failed with status: {batch_job.status}")
|
||||
return None
|
||||
|
||||
|
||||
# Initialize the OpenAIBatchProcessor
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
processor = OpenAIBatchProcessor(api_key)
|
||||
|
||||
# Process the batch job
|
||||
input_file_path = "input.jsonl"
|
||||
endpoint = "/v1/chat/completions"
|
||||
completion_window = "24h"
|
||||
|
||||
# Process the batch job
|
||||
results = processor.process_batch(input_file_path, endpoint, completion_window)
|
||||
|
||||
# Print the results
|
||||
print(results)
|
||||
97
examples/runtime/openai_batch_complete.py
Normal file
97
examples/runtime/openai_batch_complete.py
Normal file
@@ -0,0 +1,97 @@
|
||||
"""
|
||||
Usage:
|
||||
python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000
|
||||
python openai_batch_complete.py
|
||||
Note: Before running this script,
|
||||
you should create the input.jsonl file with the following content:
|
||||
{"custom_id": "request-1", "method": "POST", "url": "/v1/completions", "body": {"model": "gpt-3.5-turbo-instruct", "prompt": "List 3 names of famous soccer player: ", "max_tokens": 200}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/completions", "body": {"model": "gpt-3.5-turbo-instruct", "prompt": "List 6 names of famous basketball player: ", "max_tokens": 400}}
|
||||
{"custom_id": "request-3", "method": "POST", "url": "/v1/completions", "body": {"model": "gpt-3.5-turbo-instruct", "prompt": "List 6 names of famous basketball player: ", "max_tokens": 400}}
|
||||
"""
|
||||
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
|
||||
import openai
|
||||
from openai import OpenAI
|
||||
|
||||
|
||||
class OpenAIBatchProcessor:
|
||||
def __init__(self, api_key):
|
||||
# client = OpenAI(api_key=api_key)
|
||||
client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
|
||||
|
||||
self.client = client
|
||||
|
||||
def process_batch(self, input_file_path, endpoint, completion_window):
|
||||
|
||||
# Upload the input file
|
||||
with open(input_file_path, "rb") as file:
|
||||
uploaded_file = self.client.files.create(file=file, purpose="batch")
|
||||
|
||||
# Create the batch job
|
||||
batch_job = self.client.batches.create(
|
||||
input_file_id=uploaded_file.id,
|
||||
endpoint=endpoint,
|
||||
completion_window=completion_window,
|
||||
)
|
||||
|
||||
# Monitor the batch job status
|
||||
while batch_job.status not in ["completed", "failed", "cancelled"]:
|
||||
time.sleep(3) # Wait for 3 seconds before checking the status again
|
||||
print(
|
||||
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
|
||||
)
|
||||
batch_job = self.client.batches.retrieve(batch_job.id)
|
||||
|
||||
# Check the batch job status and errors
|
||||
if batch_job.status == "failed":
|
||||
print(f"Batch job failed with status: {batch_job.status}")
|
||||
print(f"Batch job errors: {batch_job.errors}")
|
||||
return None
|
||||
|
||||
# If the batch job is completed, process the results
|
||||
if batch_job.status == "completed":
|
||||
|
||||
# print result of batch job
|
||||
print("batch", batch_job.request_counts)
|
||||
|
||||
result_file_id = batch_job.output_file_id
|
||||
# Retrieve the file content from the server
|
||||
file_response = self.client.files.content(result_file_id)
|
||||
result_content = file_response.read() # Read the content of the file
|
||||
|
||||
# Save the content to a local file
|
||||
result_file_name = "batch_job_complete_results.jsonl"
|
||||
with open(result_file_name, "wb") as file:
|
||||
file.write(result_content) # Write the binary content to the file
|
||||
# Load data from the saved JSONL file
|
||||
results = []
|
||||
with open(result_file_name, "r", encoding="utf-8") as file:
|
||||
for line in file:
|
||||
json_object = json.loads(
|
||||
line.strip()
|
||||
) # Parse each line as a JSON object
|
||||
results.append(json_object)
|
||||
|
||||
return results
|
||||
else:
|
||||
print(f"Batch job failed with status: {batch_job.status}")
|
||||
return None
|
||||
|
||||
|
||||
# Initialize the OpenAIBatchProcessor
|
||||
api_key = os.environ.get("OPENAI_API_KEY")
|
||||
processor = OpenAIBatchProcessor(api_key)
|
||||
|
||||
# Process the batch job
|
||||
input_file_path = "input_complete.jsonl"
|
||||
endpoint = "/v1/completions"
|
||||
completion_window = "24h"
|
||||
|
||||
# Process the batch job
|
||||
results = processor.process_batch(input_file_path, endpoint, completion_window)
|
||||
|
||||
# Print the results
|
||||
print(results)
|
||||
Reference in New Issue
Block a user