[FEAT] Support batches cancel (#1222)
Co-authored-by: Yineng Zhang <me@zhyncs.com>
@@ -275,10 +275,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
         end_point = batch_storage[batch_id].endpoint
         file_request_list = []
         all_requests = []
+        request_ids = []
         for line in lines:
             request_data = json.loads(line)
             file_request_list.append(request_data)
             body = request_data["body"]
+            request_ids.append(request_data["custom_id"])
 
             # Although streaming is supported for standalone completions, it is not supported in
             # batch mode (multiple completions in single request).
@@ -289,12 +291,16 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
                 all_requests.append(ChatCompletionRequest(**body))
             elif end_point == "/v1/completions":
                 all_requests.append(CompletionRequest(**body))
+
         if end_point == "/v1/chat/completions":
             adapted_request, request = v1_chat_generate_request(
-                all_requests, tokenizer_manager
+                all_requests, tokenizer_manager, request_ids=request_ids
             )
         elif end_point == "/v1/completions":
-            adapted_request, request = v1_generate_request(all_requests)
+            adapted_request, request = v1_generate_request(
+                all_requests, request_ids=request_ids
+            )
+
         try:
             ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
             if not isinstance(ret, list):
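
For reference, each line of the uploaded batch file is a JSON object whose "custom_id" now doubles as the engine-side request id, which is what makes cancellation by id possible. A minimal sketch of one such input line (field values are illustrative; the shape follows the OpenAI batch format that process_batch parses):

import json

# Illustrative batch input: "custom_id" is collected into request_ids
# and later forwarded to the engine as the rid of each sub-request.
content = [
    {
        "custom_id": "request-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "default",
            "messages": [{"role": "user", "content": "Hello!"}],
        },
    }
]

with open("batch_input.jsonl", "w") as f:
    for line in content:
        f.write(json.dumps(line) + "\n")
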
@@ -326,6 +332,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
                 }
                 all_ret.append(response_json)
                 completed_requests += 1
+
         # Write results to a new file
         output_file_id = f"backend_result_file-{uuid.uuid4()}"
         global storage_dir
@@ -372,6 +379,72 @@ async def v1_retrieve_batch(batch_id: str):
     return batch_response
 
 
+async def v1_cancel_batch(tokenizer_manager, batch_id: str):
+    # Retrieve the batch job from the in-memory storage
+    batch_response = batch_storage.get(batch_id)
+    if batch_response is None:
+        raise HTTPException(status_code=404, detail="Batch not found")
+
+    # Only cancel when the status is "validating" or "in_progress"
+    if batch_response.status in ["validating", "in_progress"]:
+        # Start cancelling the batch asynchronously
+        asyncio.create_task(
+            cancel_batch(
+                tokenizer_manager=tokenizer_manager,
+                batch_id=batch_id,
+                input_file_id=batch_response.input_file_id,
+            )
+        )
+
+        # Update batch status to "cancelling"
+        batch_response.status = "cancelling"
+
+        return batch_response
+    else:
+        raise HTTPException(
+            status_code=500,
+            detail=f"Current status is {batch_response.status}, no need to cancel",
+        )
+
+
+async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
+    try:
+        # Update the batch status to "cancelling"
+        batch_storage[batch_id].status = "cancelling"
+
+        # Retrieve the input file content
+        input_file_request = file_id_request.get(input_file_id)
+        if not input_file_request:
+            raise ValueError("Input file not found")
+
+        # Parse the JSONL file and collect the request ids
+        input_file_path = file_id_storage.get(input_file_id)
+        with open(input_file_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+
+        file_request_list = []
+        request_ids = []
+        for line in lines:
+            request_data = json.loads(line)
+            file_request_list.append(request_data)
+            request_ids.append(request_data["custom_id"])
+
+        # Cancel requests by request_ids
+        for rid in request_ids:
+            tokenizer_manager.abort_request(rid=rid)
+
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "cancelled"
+
+    except Exception as e:
+        logger.error(f"error in SGLang: {e}")
+        # Update batch status to "failed"
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "failed"
+        retrieve_batch.failed_at = int(time.time())
+        retrieve_batch.errors = {"message": str(e)}
+
+
 async def v1_retrieve_file(file_id: str):
     # Retrieve the batch job from the in-memory storage
     file_response = file_id_response.get(file_id)
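
The split between v1_cancel_batch and cancel_batch is a fire-and-forget pattern: asyncio.create_task schedules the abort loop on the running event loop and returns immediately, so the HTTP handler can answer with status "cancelling" before the aborts finish. A toy sketch of the pattern (stub names, not the real SGLang objects):

import asyncio

async def cancel_batch_stub(batch_id: str) -> None:
    # Stands in for the loop that aborts each request id.
    await asyncio.sleep(0.1)
    print(f"{batch_id}: cancelled")

async def v1_cancel_batch_stub(batch_id: str) -> str:
    # Schedule the cancellation and return without awaiting it.
    asyncio.create_task(cancel_batch_stub(batch_id))
    return "cancelling"

async def main() -> None:
    print(await v1_cancel_batch_stub("batch-123"))  # prints "cancelling" first
    await asyncio.sleep(0.2)  # keep the loop alive so the task can finish

asyncio.run(main())
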
@@ -392,7 +465,9 @@ async def v1_retrieve_file_content(file_id: str):
     return StreamingResponse(iter_file(), media_type="application/octet-stream")
 
 
-def v1_generate_request(all_requests: List[CompletionRequest]):
+def v1_generate_request(
+    all_requests: List[CompletionRequest], request_ids: List[str] = None
+):
     prompts = []
     sampling_params_list = []
     return_logprobs = []
@@ -464,6 +539,7 @@ def v1_generate_request(all_requests: List[CompletionRequest]):
         logprob_start_len=logprob_start_lens,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
+        rid=request_ids,
     )
 
     if len(all_requests) == 1:
@@ -746,7 +822,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 
 
 def v1_chat_generate_request(
-    all_requests: List[ChatCompletionRequest], tokenizer_manager
+    all_requests: List[ChatCompletionRequest],
+    tokenizer_manager,
+    request_ids: List[str] = None,
 ):
     input_ids = []
     sampling_params_list = []
@@ -834,6 +912,7 @@ def v1_chat_generate_request(
         top_logprobs_num=top_logprobs_nums,
         stream=all_requests[0].stream,
         return_text_in_logprobs=True,
+        rid=request_ids,
     )
     if len(all_requests) == 1:
         return adapted_request, all_requests[0]
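
In both builders the new rid=request_ids forwards the batch's custom_id list to the engine, so each sub-request is registered under a caller-chosen id instead of an auto-generated one. A reduced mimic of the idea (GenerateReqInputStub is a stand-in, not the real GenerateReqInput):

from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class GenerateReqInputStub:
    # The real GenerateReqInput carries many more fields; only the
    # pieces relevant to the rid plumbing are mimicked here.
    text: Union[List[str], str]
    rid: Optional[Union[List[str], str]] = None

adapted = GenerateReqInputStub(
    text=["prompt one", "prompt two"],
    rid=["request-1", "request-2"],  # one custom_id per batched request
)
# Aborting "request-1" later can target exactly the first entry.
assert adapted.rid[0] == "request-1"
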
@@ -59,6 +59,7 @@ from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
     v1_batches,
+    v1_cancel_batch,
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
@@ -246,6 +247,12 @@ async def openai_v1_batches(raw_request: Request):
     return await v1_batches(tokenizer_manager, raw_request)
 
 
+@app.post("/v1/batches/{batch_id}/cancel")
+async def cancel_batches(batch_id: str):
+    # https://platform.openai.com/docs/api-reference/batch/cancel
+    return await v1_cancel_batch(tokenizer_manager, batch_id)
+
+
 @app.get("/v1/batches/{batch_id}")
 async def retrieve_batch(batch_id: str):
     return await v1_retrieve_batch(batch_id)
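
With this route in place, the standard OpenAI client can cancel a batch against a locally launched SGLang server. A hedged usage sketch (base URL, API key, and batch id are placeholders):

import openai

client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

# batches.cancel issues POST /v1/batches/{batch_id}/cancel under the hood.
batch_job = client.batches.cancel(batch_id="batch_abc123")  # placeholder id
print(batch_job.status)  # "cancelling" while the background aborts run
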
@@ -256,8 +256,7 @@ class TestOpenAIServer(unittest.TestCase):
                 index, True
             ), f"index {index} is not found in the response"
 
-    def run_batch(self, mode):
-        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+    def _create_batch(self, mode, client):
         if mode == "completion":
             input_file_path = "complete_input.jsonl"
             # write content to input file
@@ -333,9 +332,11 @@ class TestOpenAIServer(unittest.TestCase):
                     },
                },
            ]
+
        with open(input_file_path, "w") as file:
            for line in content:
                file.write(json.dumps(line) + "\n")
+
        with open(input_file_path, "rb") as file:
            uploaded_file = client.files.create(file=file, purpose="batch")
        if mode == "completion":
@@ -348,6 +349,13 @@ class TestOpenAIServer(unittest.TestCase):
             endpoint=endpoint,
             completion_window=completion_window,
         )
+
+        return batch_job, content
+
+    def run_batch(self, mode):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        batch_job, content = self._create_batch(mode=mode, client=client)
+
         while batch_job.status not in ["completed", "failed", "cancelled"]:
             time.sleep(3)
             print(
@@ -371,6 +379,24 @@ class TestOpenAIServer(unittest.TestCase):
         ]
         assert len(results) == len(content)
 
+    def run_cancel_batch(self, mode):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        batch_job, _ = self._create_batch(mode=mode, client=client)
+
+        assert batch_job.status not in ["cancelling", "cancelled"]
+
+        batch_job = client.batches.cancel(batch_id=batch_job.id)
+        assert batch_job.status == "cancelling"
+
+        while batch_job.status not in ["failed", "cancelled"]:
+            batch_job = client.batches.retrieve(batch_job.id)
+            print(
+                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
+            )
+            time.sleep(3)
+
+        assert batch_job.status == "cancelled"
+
     def test_completion(self):
         for echo in [False, True]:
             for logprobs in [None, 5]:
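
The polling loop above passes once the background cancel_batch has aborted every rid and flipped the status to "cancelled". A toy stand-in for the abort-by-id bookkeeping that tokenizer_manager.abort_request performs (invented internals, for illustration only):

class TokenizerManagerStub:
    """Not SGLang's TokenizerManager; just the abort-by-rid idea."""

    def __init__(self) -> None:
        # In-flight requests keyed by rid (the batch custom_ids).
        self.in_flight = {"request-1": object(), "request-2": object()}

    def abort_request(self, rid: str) -> None:
        self.in_flight.pop(rid, None)

tm = TokenizerManagerStub()
for rid in ["request-1", "request-2"]:
    tm.abort_request(rid=rid)
assert not tm.in_flight  # all batched requests aborted
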
@@ -414,6 +440,10 @@ class TestOpenAIServer(unittest.TestCase):
         for mode in ["completion", "chat"]:
             self.run_batch(mode)
 
+    def test_cancel_batch(self):
+        for mode in ["completion", "chat"]:
+            self.run_cancel_batch(mode)
+
     def test_regex(self):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
 