From bc20e93f2df15f81e624cd624e693406180e9178 Mon Sep 17 00:00:00 2001 From: KCFindstr Date: Thu, 27 Feb 2025 19:42:15 -0800 Subject: [PATCH] [feat] Add Vertex AI compatible prediction route for /generate (#3866) --- examples/runtime/vertex_predict.py | 66 ++++++++++++++++++++ python/sglang/srt/entrypoints/http_server.py | 27 ++++++++ python/sglang/srt/managers/io_struct.py | 6 ++ test/srt/run_suite.py | 1 + test/srt/test_vertex_endpoint.py | 52 +++++++++++++++ 5 files changed, 152 insertions(+) create mode 100644 examples/runtime/vertex_predict.py create mode 100644 test/srt/test_vertex_endpoint.py diff --git a/examples/runtime/vertex_predict.py b/examples/runtime/vertex_predict.py new file mode 100644 index 000000000..58a41b1c4 --- /dev/null +++ b/examples/runtime/vertex_predict.py @@ -0,0 +1,66 @@ +""" +Usage: +python -m sglang.launch_server --model meta-llama/Llama-2-7b-hf --port 30000 +python vertex_predict.py + +This example shows the request and response formats of the prediction route for +Google Cloud Vertex AI Online Predictions. + +Vertex AI SDK for Python is recommended for deploying models to Vertex AI +instead of a local server. After deploying the model to a Vertex AI Online +Prediction Endpoint, send requests via the Python SDK: + +response = endpoint.predict( + instances=[ + {"text": "The capital of France is"}, + {"text": "What is a car?"}, + ], + parameters={"sampling_params": {"max_new_tokens": 16}}, +) +print(response.predictions) + +More details about getting online predictions from Vertex AI can be found at +https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions. 
+""" + +from dataclasses import dataclass +from typing import List, Optional + +import requests + + +@dataclass +class VertexPrediction: + predictions: List + + +class LocalVertexEndpoint: + def __init__(self) -> None: + self.base_url = "http://127.0.0.1:30000" + + def predict(self, instances: List[dict], parameters: Optional[dict] = None): + response = requests.post( + self.base_url + "/vertex_generate", + json={ + "instances": instances, + "parameters": parameters, + }, + ) + return VertexPrediction(predictions=response.json()["predictions"]) + + +endpoint = LocalVertexEndpoint() + +# Predict with a single prompt. +response = endpoint.predict(instances=[{"text": "The capital of France is"}]) +print(response.predictions) + +# Predict with multiple prompts and parameters. +response = endpoint.predict( + instances=[ + {"text": "The capital of France is"}, + {"text": "What is a car?"}, + ], + parameters={"sampling_params": {"max_new_tokens": 16}}, +) +print(response.predictions) diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 2b2421a37..f84089d05 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import ( ResumeMemoryOccupationReqInput, UpdateWeightFromDiskReqInput, UpdateWeightsFromDistributedReqInput, + VertexGenerateReqInput, ) from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.metrics.func_timer import enable_func_timer @@ -475,6 +476,32 @@ async def sagemaker_chat_completions(raw_request: Request): return await v1_chat_completions(_global_state.tokenizer_manager, raw_request) +## Vertex AI API +@app.post(os.environ.get("AIP_PREDICT_ROUTE", "/vertex_generate")) +async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request): + if not vertex_req.instances: + return [] + inputs = {} + for input_key in ("text", "input_ids", "input_embeds"): + 
if vertex_req.instances[0].get(input_key): + inputs[input_key] = [ + instance.get(input_key) for instance in vertex_req.instances + ] + break + image_data = [ + instance.get("image_data") + for instance in vertex_req.instances + if instance.get("image_data") is not None + ] or None + req = GenerateReqInput( + **inputs, + image_data=image_data, + **(vertex_req.parameters or {}), + ) + ret = await generate_request(req, raw_request) + return ORJSONResponse({"predictions": ret}) + + def _create_error_response(e): return ORJSONResponse( {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST diff --git a/python/sglang/srt/managers/io_struct.py b/python/sglang/srt/managers/io_struct.py index 67225cf84..1c7be3053 100644 --- a/python/sglang/srt/managers/io_struct.py +++ b/python/sglang/srt/managers/io_struct.py @@ -568,3 +568,9 @@ class FunctionCallReqInput: tool_call_parser: Optional[str] = ( None # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all. 
) + + +@dataclass +class VertexGenerateReqInput: + instances: List[dict] + parameters: Optional[dict] = None diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 506b87bf6..b02bbec56 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -50,6 +50,7 @@ suites = { "test_hidden_states.py", "test_update_weights_from_disk.py", "test_update_weights_from_tensor.py", + "test_vertex_endpoint.py", "test_vision_chunked_prefill.py", "test_vision_llm.py", "test_vision_openai_server.py", diff --git a/test/srt/test_vertex_endpoint.py b/test/srt/test_vertex_endpoint.py new file mode 100644 index 000000000..728d0d1d2 --- /dev/null +++ b/test/srt/test_vertex_endpoint.py @@ -0,0 +1,52 @@ +""" +python3 -m unittest test_vertex_endpoint.TestVertexEndpoint.test_vertex_generate +""" + +import unittest + +import requests + +from sglang.srt.utils import kill_process_tree +from sglang.test.test_utils import ( + DEFAULT_SMALL_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +class TestVertexEndpoint(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def run_generate(self, parameters): + data = { + "instances": [ + {"text": "The capital of France is"}, + {"text": "The capital of China is"}, + ], + "parameters": parameters, + } + response = requests.post(self.base_url + "/vertex_generate", json=data) + response_json = response.json() + assert len(response_json["predictions"]) == len(data["instances"]) + return response_json + + def test_vertex_generate(self): + for parameters in [None, {"sampling_params": {"max_new_tokens": 4}}]: + self.run_generate(parameters) + + +if __name__ == "__main__": + unittest.main()