[FEAT] Support batches cancel (#1222)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
This commit is contained in:
@@ -256,8 +256,7 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
index, True
|
||||
), f"index {index} is not found in the response"
|
||||
|
||||
def run_batch(self, mode):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
def _create_batch(self, mode, client):
|
||||
if mode == "completion":
|
||||
input_file_path = "complete_input.jsonl"
|
||||
# write content to input file
|
||||
@@ -333,9 +332,11 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
with open(input_file_path, "w") as file:
|
||||
for line in content:
|
||||
file.write(json.dumps(line) + "\n")
|
||||
|
||||
with open(input_file_path, "rb") as file:
|
||||
uploaded_file = client.files.create(file=file, purpose="batch")
|
||||
if mode == "completion":
|
||||
@@ -348,6 +349,13 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
endpoint=endpoint,
|
||||
completion_window=completion_window,
|
||||
)
|
||||
|
||||
return batch_job, content
|
||||
|
||||
def run_batch(self, mode):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
batch_job, content = self._create_batch(mode=mode, client=client)
|
||||
|
||||
while batch_job.status not in ["completed", "failed", "cancelled"]:
|
||||
time.sleep(3)
|
||||
print(
|
||||
@@ -371,6 +379,24 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
]
|
||||
assert len(results) == len(content)
|
||||
|
||||
def run_cancel_batch(self, mode):
    """Create a batch job, cancel it, and poll until it reaches a terminal state.

    Args:
        mode: Batch flavor, "completion" or "chat"; forwarded to ``_create_batch``.

    Raises:
        AssertionError: If the freshly created job is already cancelling/cancelled,
            if ``batches.cancel`` does not move it to "cancelling", or if the job
            does not end in the "cancelled" state.
    """
    client = openai.Client(api_key=self.api_key, base_url=self.base_url)
    batch_job, _ = self._create_batch(mode=mode, client=client)

    # A job we just created must not already be in a cancellation state.
    assert batch_job.status not in ["cancelling", "cancelled"]

    batch_job = client.batches.cancel(batch_id=batch_job.id)
    assert batch_job.status == "cancelling"

    # Poll until a terminal state. "completed" is included so the loop cannot
    # spin forever when the job finishes before the cancel takes effect; the
    # final assert below still reports that case as a test failure.
    while batch_job.status not in ["failed", "cancelled", "completed"]:
        batch_job = client.batches.retrieve(batch_job.id)
        print(
            f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
        )
        time.sleep(3)

    assert batch_job.status == "cancelled"
|
||||
|
||||
def test_completion(self):
|
||||
for echo in [False, True]:
|
||||
for logprobs in [None, 5]:
|
||||
@@ -414,6 +440,10 @@ class TestOpenAIServer(unittest.TestCase):
|
||||
for mode in ["completion", "chat"]:
|
||||
self.run_batch(mode)
|
||||
|
||||
def test_cancel_batch(self):
    """Run the batch-cancellation flow for both completion and chat modes.

    NOTE(review): renamed from ``test_calcel_batch`` — "calcel" was a typo.
    unittest discovers tests by the ``test_`` prefix, so the corrected name
    is still collected by the runner.
    """
    for mode in ["completion", "chat"]:
        self.run_cancel_batch(mode)
|
||||
|
||||
def test_regex(self):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user