[FEAT] Support batches cancel (#1222)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
caiyueliang
2024-08-27 07:28:26 +08:00
committed by GitHub
parent c61a1b6f97
commit 2f1d92834f
3 changed files with 122 additions and 6 deletions

View File

@@ -256,8 +256,7 @@ class TestOpenAIServer(unittest.TestCase):
index, True
), f"index {index} is not found in the response"
def run_batch(self, mode):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
def _create_batch(self, mode, client):
if mode == "completion":
input_file_path = "complete_input.jsonl"
# write content to input file
@@ -333,9 +332,11 @@ class TestOpenAIServer(unittest.TestCase):
},
},
]
with open(input_file_path, "w") as file:
for line in content:
file.write(json.dumps(line) + "\n")
with open(input_file_path, "rb") as file:
uploaded_file = client.files.create(file=file, purpose="batch")
if mode == "completion":
@@ -348,6 +349,13 @@ class TestOpenAIServer(unittest.TestCase):
endpoint=endpoint,
completion_window=completion_window,
)
return batch_job, content
def run_batch(self, mode):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
batch_job, content = self._create_batch(mode=mode, client=client)
while batch_job.status not in ["completed", "failed", "cancelled"]:
time.sleep(3)
print(
@@ -371,6 +379,24 @@ class TestOpenAIServer(unittest.TestCase):
]
assert len(results) == len(content)
def run_cancel_batch(self, mode):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
batch_job, _ = self._create_batch(mode=mode, client=client)
assert batch_job.status not in ["cancelling", "cancelled"]
batch_job = client.batches.cancel(batch_id=batch_job.id)
assert batch_job.status == "cancelling"
while batch_job.status not in ["failed", "cancelled"]:
batch_job = client.batches.retrieve(batch_job.id)
print(
f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
)
time.sleep(3)
assert batch_job.status == "cancelled"
def test_completion(self):
for echo in [False, True]:
for logprobs in [None, 5]:
@@ -414,6 +440,10 @@ class TestOpenAIServer(unittest.TestCase):
for mode in ["completion", "chat"]:
self.run_batch(mode)
def test_calcel_batch(self):
for mode in ["completion", "chat"]:
self.run_cancel_batch(mode)
def test_regex(self):
client = openai.Client(api_key=self.api_key, base_url=self.base_url)