diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 148f2689d..4feb632b0 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -275,10 +275,12 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
     end_point = batch_storage[batch_id].endpoint
     file_request_list = []
     all_requests = []
+    request_ids = []
     for line in lines:
         request_data = json.loads(line)
         file_request_list.append(request_data)
         body = request_data["body"]
+        request_ids.append(request_data["custom_id"])

         # Although streaming is supported for standalone completions, it is not supported in
         # batch mode (multiple completions in single request).
@@ -289,12 +291,16 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             all_requests.append(ChatCompletionRequest(**body))
         elif end_point == "/v1/completions":
             all_requests.append(CompletionRequest(**body))
+
     if end_point == "/v1/chat/completions":
         adapted_request, request = v1_chat_generate_request(
-            all_requests, tokenizer_manager
+            all_requests, tokenizer_manager, request_ids=request_ids
         )
     elif end_point == "/v1/completions":
-        adapted_request, request = v1_generate_request(all_requests)
+        adapted_request, request = v1_generate_request(
+            all_requests, request_ids=request_ids
+        )
+
     try:
         ret = await tokenizer_manager.generate_request(adapted_request).__anext__()
         if not isinstance(ret, list):
@@ -326,6 +332,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
            }
            all_ret.append(response_json)
            completed_requests += 1
+
        # Write results to a new file
        output_file_id = f"backend_result_file-{uuid.uuid4()}"
        global storage_dir
@@ -372,6 +379,72 @@ async def v1_retrieve_batch(batch_id: str):
     return batch_response


+async def v1_cancel_batch(tokenizer_manager, batch_id: str):
+    # Retrieve the batch job from the in-memory storage
+    batch_response = batch_storage.get(batch_id)
+    if batch_response is None:
+        raise HTTPException(status_code=404, detail="Batch not found")
+
+    # Only cancel when the status is "validating" or "in_progress"
+    if batch_response.status in ["validating", "in_progress"]:
+        # Start cancelling the batch asynchronously
+        asyncio.create_task(
+            cancel_batch(
+                tokenizer_manager=tokenizer_manager,
+                batch_id=batch_id,
+                input_file_id=batch_response.input_file_id,
+            )
+        )
+
+        # Update batch status to "cancelling"
+        batch_response.status = "cancelling"
+
+        return batch_response
+    else:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Current status is {batch_response.status}, which cannot be cancelled",
+        )
+
+
+async def cancel_batch(tokenizer_manager, batch_id: str, input_file_id: str):
+    try:
+        # Update the batch status to "cancelling"
+        batch_storage[batch_id].status = "cancelling"
+
+        # Retrieve the input file content
+        input_file_request = file_id_request.get(input_file_id)
+        if not input_file_request:
+            raise ValueError("Input file not found")
+
+        # Parse the JSONL file and process each request
+        input_file_path = file_id_storage.get(input_file_id)
+        with open(input_file_path, "r", encoding="utf-8") as f:
+            lines = f.readlines()
+
+        file_request_list = []
+        request_ids = []
+        for line in lines:
+            request_data = json.loads(line)
+            file_request_list.append(request_data)
+            request_ids.append(request_data["custom_id"])
+
+        # Cancel requests by request_ids
+        for rid in request_ids:
+            tokenizer_manager.abort_request(rid=rid)
+
+        retrieve_batch = batch_storage[batch_id]
+        retrieve_batch.status = "cancelled"
"cancelled" + + except Exception as e: + logger.error("error in SGLang:", e) + # Update batch status to "failed" + retrieve_batch = batch_storage[batch_id] + retrieve_batch.status = "failed" + retrieve_batch.failed_at = int(time.time()) + retrieve_batch.errors = {"message": str(e)} + + async def v1_retrieve_file(file_id: str): # Retrieve the batch job from the in-memory storage file_response = file_id_response.get(file_id) @@ -392,7 +465,9 @@ async def v1_retrieve_file_content(file_id: str): return StreamingResponse(iter_file(), media_type="application/octet-stream") -def v1_generate_request(all_requests: List[CompletionRequest]): +def v1_generate_request( + all_requests: List[CompletionRequest], request_ids: List[str] = None +): prompts = [] sampling_params_list = [] return_logprobs = [] @@ -464,6 +539,7 @@ def v1_generate_request(all_requests: List[CompletionRequest]): logprob_start_len=logprob_start_lens, return_text_in_logprobs=True, stream=all_requests[0].stream, + rid=request_ids, ) if len(all_requests) == 1: @@ -746,7 +822,9 @@ async def v1_completions(tokenizer_manager, raw_request: Request): def v1_chat_generate_request( - all_requests: List[ChatCompletionRequest], tokenizer_manager + all_requests: List[ChatCompletionRequest], + tokenizer_manager, + request_ids: List[str] = None, ): input_ids = [] sampling_params_list = [] @@ -834,6 +912,7 @@ def v1_chat_generate_request( top_logprobs_num=top_logprobs_nums, stream=all_requests[0].stream, return_text_in_logprobs=True, + rid=request_ids, ) if len(all_requests) == 1: return adapted_request, all_requests[0] diff --git a/python/sglang/srt/server.py b/python/sglang/srt/server.py index 021f231aa..6d1fc9fda 100644 --- a/python/sglang/srt/server.py +++ b/python/sglang/srt/server.py @@ -59,6 +59,7 @@ from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.openai_api.adapter import ( load_chat_template_for_openai_api, v1_batches, + v1_cancel_batch, v1_chat_completions, v1_completions, v1_delete_file, @@ -246,6 +247,12 @@ async def openai_v1_batches(raw_request: Request): return await v1_batches(tokenizer_manager, raw_request) +@app.post("/v1/batches/{batch_id}/cancel") +async def cancel_batches(batch_id: str): + # https://platform.openai.com/docs/api-reference/batch/cancel + return await v1_cancel_batch(tokenizer_manager, batch_id) + + @app.get("/v1/batches/{batch_id}") async def retrieve_batch(batch_id: str): return await v1_retrieve_batch(batch_id) diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index ce130956d..cfc65b7e6 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -256,8 +256,7 @@ class TestOpenAIServer(unittest.TestCase): index, True ), f"index {index} is not found in the response" - def run_batch(self, mode): - client = openai.Client(api_key=self.api_key, base_url=self.base_url) + def _create_batch(self, mode, client): if mode == "completion": input_file_path = "complete_input.jsonl" # write content to input file @@ -333,9 +332,11 @@ class TestOpenAIServer(unittest.TestCase): }, }, ] + with open(input_file_path, "w") as file: for line in content: file.write(json.dumps(line) + "\n") + with open(input_file_path, "rb") as file: uploaded_file = client.files.create(file=file, purpose="batch") if mode == "completion": @@ -348,6 +349,13 @@ class TestOpenAIServer(unittest.TestCase): endpoint=endpoint, completion_window=completion_window, ) + + return batch_job, content + + def run_batch(self, mode): + client = openai.Client(api_key=self.api_key, 
+        batch_job, content = self._create_batch(mode=mode, client=client)
+
         while batch_job.status not in ["completed", "failed", "cancelled"]:
             time.sleep(3)
             print(
@@ -371,6 +379,24 @@ class TestOpenAIServer(unittest.TestCase):
         ]
         assert len(results) == len(content)

+    def run_cancel_batch(self, mode):
+        client = openai.Client(api_key=self.api_key, base_url=self.base_url)
+        batch_job, _ = self._create_batch(mode=mode, client=client)
+
+        assert batch_job.status not in ["cancelling", "cancelled"]
+
+        batch_job = client.batches.cancel(batch_id=batch_job.id)
+        assert batch_job.status == "cancelling"
+
+        while batch_job.status not in ["failed", "cancelled"]:
+            batch_job = client.batches.retrieve(batch_job.id)
+            print(
+                f"Batch job status: {batch_job.status}...trying again in 3 seconds..."
+            )
+            time.sleep(3)
+
+        assert batch_job.status == "cancelled"
+
     def test_completion(self):
         for echo in [False, True]:
             for logprobs in [None, 5]:
@@ -414,6 +440,10 @@ class TestOpenAIServer(unittest.TestCase):
         for mode in ["completion", "chat"]:
             self.run_batch(mode)

+    def test_cancel_batch(self):
+        for mode in ["completion", "chat"]:
+            self.run_cancel_batch(mode)
+
     def test_regex(self):
         client = openai.Client(api_key=self.api_key, base_url=self.base_url)
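
For reference, a minimal sketch of how a client would exercise the new cancel route once this patch is applied. This is not part of the diff: the base URL, API key, model name, and input file name are illustrative assumptions, and the call sequence simply mirrors the `_create_batch` / `run_cancel_batch` flow in the test above.

```python
# Hypothetical end-to-end sketch of the /v1/batches/{batch_id}/cancel route.
# Assumes an sglang server listening on localhost:30000. Each JSONL line must
# carry a unique "custom_id": process_batch now forwards these ids as rids,
# which is what lets cancel_batch abort the in-flight requests.
import json

import openai

client = openai.Client(api_key="EMPTY", base_url="http://127.0.0.1:30000/v1")

requests = [
    {
        "custom_id": f"request-{i}",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "default",  # assumed served model name
            "messages": [{"role": "user", "content": "Write a long story."}],
        },
    }
    for i in range(8)
]
with open("batch_input.jsonl", "w") as f:
    for req in requests:
        f.write(json.dumps(req) + "\n")

with open("batch_input.jsonl", "rb") as f:
    uploaded_file = client.files.create(file=f, purpose="batch")

batch_job = client.batches.create(
    input_file_id=uploaded_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
)

# The endpoint returns immediately with status "cancelling"; the background
# cancel_batch task aborts each request by its custom_id and then marks the
# batch "cancelled" (or "failed" if an error occurred along the way).
batch_job = client.batches.cancel(batch_id=batch_job.id)
print(batch_job.status)  # "cancelling"
```

Note that cancellation is asynchronous by design: `v1_cancel_batch` only schedules `cancel_batch` via `asyncio.create_task`, so callers should poll `client.batches.retrieve` until the status settles on "cancelled" or "failed", exactly as `run_cancel_batch` does in the test.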