[feat] Add Vertex AI compatible prediction route for /generate (#3866)
This commit is contained in:
66
examples/runtime/vertex_predict.py
Normal file
66
examples/runtime/vertex_predict.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""
|
||||||
|
Usage:
|
||||||
|
python -m sglang.launch_server --model meta-llama/Llama-2-7b-hf --port 30000
|
||||||
|
python vertex_predict.py
|
||||||
|
|
||||||
|
This example shows the request and response formats of the prediction route for
|
||||||
|
Google Cloud Vertex AI Online Predictions.
|
||||||
|
|
||||||
|
Vertex AI SDK for Python is recommended for deploying models to Vertex AI
|
||||||
|
instead of a local server. After deploying the model to a Vertex AI Online
|
||||||
|
Prediction Endpoint, send requests via the Python SDK:
|
||||||
|
|
||||||
|
response = endpoint.predict(
|
||||||
|
instances=[
|
||||||
|
{"text": "The capital of France is"},
|
||||||
|
{"text": "What is a car?"},
|
||||||
|
],
|
||||||
|
parameters={"sampling_params": {"max_new_tokens": 16}},
|
||||||
|
)
|
||||||
|
print(response.predictions)
|
||||||
|
|
||||||
|
More details about getting online predictions from Vertex AI can be found at
|
||||||
|
https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VertexPrediction:
    # Mirrors the response object returned by the Vertex AI SDK's
    # ``endpoint.predict``: ``predictions`` holds one generation result per
    # input instance, in request order.
    predictions: List
|
||||||
|
|
||||||
|
|
||||||
|
class LocalVertexEndpoint:
    """Minimal stand-in for a Vertex AI ``Endpoint`` that targets a local
    sglang server instead of a deployed Google Cloud endpoint.

    It mimics the ``endpoint.predict(instances=..., parameters=...)`` call
    shape of the Vertex AI SDK for Python, so this example can be run
    without deploying the model.
    """

    def __init__(self, base_url: str = "http://127.0.0.1:30000") -> None:
        """Create the endpoint wrapper.

        Args:
            base_url: Base URL of the locally launched sglang server.
                Defaults to the address used by ``sglang.launch_server``
                in this example.
        """
        self.base_url = base_url

    def predict(self, instances: List[dict], parameters: Optional[dict] = None):
        """Send a Vertex AI style prediction request.

        Args:
            instances: Input instances, e.g. ``{"text": "..."}``.
            parameters: Optional request-wide parameters such as
                ``{"sampling_params": {"max_new_tokens": 16}}``.

        Returns:
            A ``VertexPrediction`` whose ``predictions`` field holds the
            server's per-instance generation results.

        Raises:
            requests.HTTPError: If the server replies with an error status.
        """
        response = requests.post(
            self.base_url + "/vertex_generate",
            json={
                "instances": instances,
                "parameters": parameters,
            },
        )
        # Fail loudly on HTTP errors instead of raising a confusing
        # KeyError on the missing "predictions" key below.
        response.raise_for_status()
        return VertexPrediction(predictions=response.json()["predictions"])
|
||||||
|
|
||||||
|
|
||||||
|
endpoint = LocalVertexEndpoint()

# Single-prompt request: one instance, default parameters.
single = endpoint.predict(instances=[{"text": "The capital of France is"}])
print(single.predictions)

# Batched request: several instances plus request-wide sampling parameters.
batch_instances = [
    {"text": "The capital of France is"},
    {"text": "What is a car?"},
]
batch = endpoint.predict(
    instances=batch_instances,
    parameters={"sampling_params": {"max_new_tokens": 16}},
)
print(batch.predictions)
|
||||||
@@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
UpdateWeightsFromDistributedReqInput,
|
UpdateWeightsFromDistributedReqInput,
|
||||||
|
VertexGenerateReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||||
from sglang.srt.metrics.func_timer import enable_func_timer
|
from sglang.srt.metrics.func_timer import enable_func_timer
|
||||||
@@ -475,6 +476,32 @@ async def sagemaker_chat_completions(raw_request: Request):
|
|||||||
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
|
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
|
||||||
|
|
||||||
|
|
||||||
|
## Vertex AI API
# Vertex AI serving containers receive the prediction route via the
# AIP_PREDICT_ROUTE environment variable; fall back to /vertex_generate
# for local runs.
@app.post(os.environ.get("AIP_PREDICT_ROUTE", "/vertex_generate"))
async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
    """Handle a Vertex AI Online Prediction request.

    Converts the Vertex request body ({"instances": [...], "parameters": {...}})
    into a batched GenerateReqInput, runs generation, and wraps the result as
    {"predictions": [...]} per the Vertex AI response contract.
    """
    if not vertex_req.instances:
        return []
    inputs = {}
    # Pick exactly one input modality ("text" | "input_ids" | "input_embeds"),
    # decided by the first instance; all instances in a batch are assumed to
    # use the same key.
    # NOTE(review): the truthiness check skips falsy values (e.g. ""), so a
    # first instance with an empty text would fall through to the next key —
    # confirm whether `is not None` was intended.
    for input_key in ("text", "input_ids", "input_embeds"):
        if vertex_req.instances[0].get(input_key):
            inputs[input_key] = [
                instance.get(input_key) for instance in vertex_req.instances
            ]
            break
    # Collect per-instance image data; None when no instance supplies any.
    image_data = [
        instance.get("image_data")
        for instance in vertex_req.instances
        if instance.get("image_data") is not None
    ] or None
    # Request-wide options (e.g. sampling_params) are spread into the
    # generation request alongside the batched inputs.
    req = GenerateReqInput(
        **inputs,
        image_data=image_data,
        **(vertex_req.parameters or {}),
    )
    ret = await generate_request(req, raw_request)
    return ORJSONResponse({"predictions": ret})
|
||||||
|
|
||||||
|
|
||||||
def _create_error_response(e):
|
def _create_error_response(e):
|
||||||
return ORJSONResponse(
|
return ORJSONResponse(
|
||||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||||
|
|||||||
@@ -568,3 +568,9 @@ class FunctionCallReqInput:
|
|||||||
tool_call_parser: Optional[str] = (
|
tool_call_parser: Optional[str] = (
|
||||||
None # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
|
None # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class VertexGenerateReqInput:
    # Request body of the Vertex AI compatible prediction route.
    # Each instance is one input, e.g. {"text": ...}, {"input_ids": ...},
    # or {"input_embeds": ...}, optionally with an "image_data" entry.
    instances: List[dict]
    # Request-wide options forwarded into GenerateReqInput, e.g.
    # {"sampling_params": {"max_new_tokens": 16}}.
    parameters: Optional[dict] = None
|
||||||
|
|||||||
@@ -50,6 +50,7 @@ suites = {
|
|||||||
"test_hidden_states.py",
|
"test_hidden_states.py",
|
||||||
"test_update_weights_from_disk.py",
|
"test_update_weights_from_disk.py",
|
||||||
"test_update_weights_from_tensor.py",
|
"test_update_weights_from_tensor.py",
|
||||||
|
"test_vertex_endpoint.py",
|
||||||
"test_vision_chunked_prefill.py",
|
"test_vision_chunked_prefill.py",
|
||||||
"test_vision_llm.py",
|
"test_vision_llm.py",
|
||||||
"test_vision_openai_server.py",
|
"test_vision_openai_server.py",
|
||||||
|
|||||||
52
test/srt/test_vertex_endpoint.py
Normal file
52
test/srt/test_vertex_endpoint.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""
|
||||||
|
python3 -m unittest test_vertex_endpoint.TestVertexEndpoint.test_vertex_generate
|
||||||
|
"""
|
||||||
|
|
||||||
|
import unittest
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from sglang.srt.utils import kill_process_tree
|
||||||
|
from sglang.test.test_utils import (
|
||||||
|
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
DEFAULT_URL_FOR_TEST,
|
||||||
|
popen_launch_server,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestVertexEndpoint(unittest.TestCase):
    """End-to-end checks for the Vertex AI compatible /vertex_generate route."""

    @classmethod
    def setUpClass(cls):
        # Launch one local server shared by every test in this class.
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model, cls.base_url, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
        )

    @classmethod
    def tearDownClass(cls):
        # Shut down the server process tree started in setUpClass.
        kill_process_tree(cls.process.pid)

    def run_generate(self, parameters):
        # Send a two-instance Vertex-style request and verify the server
        # returns exactly one prediction per instance.
        instances = [
            {"text": "The capital of France is"},
            {"text": "The capital of China is"},
        ]
        payload = {"instances": instances, "parameters": parameters}
        response = requests.post(self.base_url + "/vertex_generate", json=payload)
        response_json = response.json()
        assert len(response_json["predictions"]) == len(instances)
        return response_json

    def test_vertex_generate(self):
        # Exercise the route both without parameters and with sampling params.
        for parameters in (None, {"sampling_params": {"max_new_tokens": 4}}):
            self.run_generate(parameters)
|
||||||
|
|
||||||
|
|
||||||
|
# Allow running this test file directly: python test_vertex_endpoint.py
if __name__ == "__main__":
    unittest.main()
|
||||||
Reference in New Issue
Block a user