[FEAT] Support batches cancel (#1222)

Co-authored-by: Yineng Zhang <me@zhyncs.com>
Author: caiyueliang
Date: 2024-08-27 07:28:26 +08:00
Committed by: GitHub
Parent: c61a1b6f97
Commit: 2f1d92834f

3 changed files with 122 additions and 6 deletions

python/sglang/srt/openai_api/adapter.py

@@ -275,10 +275,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
     end_point = batch_storage[batch_id].endpoint
     file_request_list = []
     all_requests = []
+    request_ids = []
     for line in lines:
         request_data = json.loads(line)
         file_request_list.append(request_data)
         body = request_data["body"]
+        request_ids.append(request_data["custom_id"])

         # Although streaming is supported for standalone completions, it is not supported in
         # batch mode (multiple completions in single request).
@@ -289,12 +291,16 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             all_requests.append(ChatCompletionRequest(**body))
         elif end_point == "/v1/completions":
             all_requests.append(CompletionRequest(**body))
+
     if end_point == "/v1/chat/completions":
         adapted_request, request = v1_chat_generate_request(
-            all_requests, tokenizer_manager
+            all_requests, tokenizer_manager, request_ids=request_ids
         )
     elif end_point == "/v1/completions":
-        adapted_request, request = v1_generate_request(all_requests)
+        adapted_request, request = v1_generate_request(
+            all_requests, request_ids=request_ids
+        )
+
     try:
         ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
         if not isinstance(ret, list):
@@ -326,6 +332,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
                 }
                 all_ret.append(response_json)
                 completed_requests += 1
+
     # Write results to a new file
     output_file_id = f"backend_result_file-{uuid.uuid4()}"
     global storage_dir
@@ -372,6 +379,72 @@ async def v1_retrieve_batch(batch_id: str):
     return batch_response


+async def v1_cancel_batch(tokenizer_manager, batch_id: str):
+    # Retrieve the batch job from the in-memory storage
+    batch_response = batch_storage.get(batch_id)
+    if batch_response is None:
+        raise HTTPException(status_code=404, detail="Batch not found")
+
+    # Only cancel when the status is "validating" or "in_progress"
+    if batch_response.status in ["validating", "in_progress"]:
+        # Start cancelling the batch asynchronously
+        asyncio.create_task(
+            cancel_batch(
+                tokenizer_manager=tokenizer_manager,
+                batch_id=batch_id,
+                input_file_id=batch_response.input_file_id,
+            )
+        )
+
+        # Update the batch status to "cancelling"
+        batch_response.status = "cancelling"
+        return batch_response
+    else:
+        raise HTTPException(
+            status_code=500,
+            detail=f"Current status is {batch_response.status}, no need to cancel",
+        )
+
+
+async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
+    try:
+        # Update the batch status to "cancelling"
+        batch_storage[batch_id].status = "cancelling"
+
+        # Retrieve the input file content
+        input_file_request = file_id_request.get(input_file_id)
+        if not input_file_request:
+            raise ValueError("Input file not found")
+
+        # Parse the JSONL input and collect the per-line request ids
+        input_file_path = file_id_storage.get(input_file_id)
+        with open(input_file_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+
+        file_request_list = []
+        request_ids = []
+        for line in lines:
+            request_data = json.loads(line)
+            file_request_list.append(request_data)
+            request_ids.append(request_data["custom_id"])
+
+        # Cancel requests by request_ids
+        for rid in request_ids:
+            tokenizer_manager.abort_request(rid=rid)
+
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "cancelled"
+    except Exception as e:
+        logger.error("error in SGLang: %s", e)
+        # Update batch status to "failed"
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "failed"
+        retrieve_batch.failed_at = int(time.time())
+        retrieve_batch.errors = {"message": str(e)}
+
+
 async def v1_retrieve_file(file_id: str):
     # Retrieve the batch job from the in-memory storage
     file_response = file_id_response.get(file_id)
@@ -392,7 +465,9 @@ async def v1_retrieve_file_content(file_id: str):
     return StreamingResponse(iter_file(), media_type="application/octet-stream")


-def v1_generate_request(all_requests: List[CompletionRequest]):
+def v1_generate_request(
+    all_requests: List[CompletionRequest], request_ids: List[str] = None
+):
     prompts = []
     sampling_params_list = []
     return_logprobs = []
@@ -464,6 +539,7 @@ def v1_generate_request(all_requests: List[CompletionRequest]):
         logprob_start_len=logprob_start_lens,
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
+        rid=request_ids,
     )

     if len(all_requests) == 1:
@@ -746,7 +822,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
 def v1_chat_generate_request(
-    all_requests: List[ChatCompletionRequest], tokenizer_manager
+    all_requests: List[ChatCompletionRequest],
+    tokenizer_manager,
+    request_ids: List[str] = None,
 ):
     input_ids = []
     sampling_params_list = []
@@ -834,6 +912,7 @@ def v1_chat_generate_request(
         top_logprobs_num=top_logprobs_nums,
         stream=all_requests[0].stream,
         return_text_in_logprobs=True,
+        rid=request_ids,
     )

     if len(all_requests) == 1:
         return adapted_request, all_requests[0]

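The cancel path is deliberately asynchronous: v1_cancel_batch validates the batch, flips its status to "cancelling", and returns immediately, while the background cancel_batch task aborts each per-line request and later marks the batch "cancelled" (or "failed" on error). A minimal sketch of how a client could observe that lifecycle against a locally running server, assuming the default port and the third-party requests library; the batch id is a placeholder:

import time

import requests

BASE = "http://127.0.0.1:30000/v1"  # assumed local SGLang server
batch_id = "batch_..."  # placeholder: id returned by POST /v1/batches

# Request cancellation; the response already reports "cancelling".
resp = requests.post(f"{BASE}/batches/{batch_id}/cancel")
resp.raise_for_status()
print(resp.json()["status"])  # "cancelling"

# Poll until the background task has aborted the per-line requests.
while True:
    status = requests.get(f"{BASE}/batches/{batch_id}").json()["status"]
    if status in ("cancelled", "failed"):
        break
    time.sleep(1)
print(status)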
python/sglang/srt/server.py

@@ -59,6 +59,7 @@ from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
     v1_batches,
+    v1_cancel_batch,
     v1_chat_completions,
     v1_completions,
     v1_delete_file,
@@ -246,6 +247,12 @@ async def openai_v1_batches(raw_request: Request):
     return await v1_batches(tokenizer_manager, raw_request)


+@app.post("/v1/batches/{batch_id}/cancel")
+async def cancel_batches(batch_id: str):
+    # https://platform.openai.com/docs/api-reference/batch/cancel
+    return await v1_cancel_batch(tokenizer_manager, batch_id)
+
+
 @app.get("/v1/batches/{batch_id}")
 async def retrieve_batch(batch_id: str):
     return await v1_retrieve_batch(batch_id)
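
Because each input line's custom_id is forwarded to the engine as the request id (the rid= plumbing added to v1_generate_request and v1_chat_generate_request), cancellation can target exactly the in-flight requests belonging to the batch via tokenizer_manager.abort_request(rid=...). An end-to-end sketch, under the assumption that the server also exposes the OpenAI-compatible /v1/files and /v1/batches endpoints; the port and file name are placeholders:

import json

import requests

BASE = "http://127.0.0.1:30000/v1"  # assumed local SGLang server

# Each line needs a unique custom_id; it doubles as the engine-side rid
# that cancel_batch() later passes to tokenizer_manager.abort_request().
lines = [
    {
        "custom_id": f"req-{i}",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "default",
            "messages": [{"role": "user", "content": f"Question {i}"}],
        },
    }
    for i in range(4)
]
with open("input.jsonl", "w") as f:
    f.write("\n".join(json.dumps(line) for line in lines))

# Upload the file, create the batch, then cancel it right away.
with open("input.jsonl", "rb") as f:
    file_id = requests.post(
        f"{BASE}/files", files={"file": f}, data={"purpose": "batch"}
    ).json()["id"]
batch_id = requests.post(
    f"{BASE}/batches",
    json={
        "input_file_id": file_id,
        "endpoint": "/v1/chat/completions",
        "completion_window": "24h",
    },
).json()["id"]
print(requests.post(f"{BASE}/batches/{batch_id}/cancel").json()["status"])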