[feat] Add Vertex AI compatible prediction route for /generate (#3866)
This commit is contained in:
66
examples/runtime/vertex_predict.py
Normal file
66
examples/runtime/vertex_predict.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""
|
||||
Usage:
|
||||
python -m sglang.launch_server --model meta-llama/Llama-2-7b-hf --port 30000
|
||||
python vertex_predict.py
|
||||
|
||||
This example shows the request and response formats of the prediction route for
|
||||
Google Cloud Vertex AI Online Predictions.
|
||||
|
||||
Vertex AI SDK for Python is recommended for deploying models to Vertex AI
|
||||
instead of a local server. After deploying the model to a Vertex AI Online
|
||||
Prediction Endpoint, send requests via the Python SDK:
|
||||
|
||||
response = endpoint.predict(
|
||||
instances=[
|
||||
{"text": "The capital of France is"},
|
||||
{"text": "What is a car?"},
|
||||
],
|
||||
parameters={"sampling_params": {"max_new_tokens": 16}},
|
||||
)
|
||||
print(response.predictions)
|
||||
|
||||
More details about get online predictions from Vertex AI can be found at
|
||||
https://cloud.google.com/vertex-ai/docs/predictions/get-online-predictions.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
@dataclass
class VertexPrediction:
    """Response wrapper mimicking the Vertex AI SDK's prediction result.

    Mirrors the object returned by ``endpoint.predict`` in the Vertex AI
    Python SDK, so this example reads the same as SDK-based code.
    """

    # One prediction per request instance, in the same order, as returned
    # under the "predictions" key of the /vertex_generate response.
    predictions: List
|
||||
|
||||
class LocalVertexEndpoint:
    """Minimal stand-in for a Vertex AI endpoint that targets a local server.

    Mimics the ``predict`` interface of the Vertex AI Python SDK so the
    request/response flow can be exercised against a locally launched sglang
    server instead of a deployed Vertex AI Online Prediction Endpoint.
    """

    def __init__(
        self,
        base_url: str = "http://127.0.0.1:30000",
        timeout: Optional[float] = None,
    ) -> None:
        """
        Args:
            base_url: Root URL of the local server. The default matches the
                launch command shown in the module docstring.
            timeout: Optional per-request timeout in seconds. ``None`` keeps
                the original wait-forever behavior of ``requests``.
        """
        self.base_url = base_url
        self.timeout = timeout

    def predict(self, instances: List[dict], parameters: Optional[dict] = None):
        """POST instances to /vertex_generate and wrap the predictions.

        Args:
            instances: One dict per prompt, e.g. ``{"text": "..."}``.
            parameters: Optional extra arguments, e.g.
                ``{"sampling_params": {"max_new_tokens": 16}}``.

        Returns:
            A ``VertexPrediction`` holding the server's predictions.

        Raises:
            requests.HTTPError: If the server responds with an error status.
        """
        response = requests.post(
            self.base_url + "/vertex_generate",
            json={
                "instances": instances,
                "parameters": parameters,
            },
            timeout=self.timeout,
        )
        # Fail loudly on HTTP errors instead of raising a confusing KeyError
        # while digging "predictions" out of an error body below.
        response.raise_for_status()
        return VertexPrediction(predictions=response.json()["predictions"])
|
||||
|
||||
endpoint = LocalVertexEndpoint()

# Exercise the route the same two ways the Vertex AI SDK examples do:
# first a single prompt with default parameters, then a batch of prompts
# with explicit sampling parameters.
example_requests = [
    ([{"text": "The capital of France is"}], None),
    (
        [
            {"text": "The capital of France is"},
            {"text": "What is a car?"},
        ],
        {"sampling_params": {"max_new_tokens": 16}},
    ),
]

for instances, parameters in example_requests:
    response = endpoint.predict(instances=instances, parameters=parameters)
    print(response.predictions)
@@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import (
|
||||
ResumeMemoryOccupationReqInput,
|
||||
UpdateWeightFromDiskReqInput,
|
||||
UpdateWeightsFromDistributedReqInput,
|
||||
VertexGenerateReqInput,
|
||||
)
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.metrics.func_timer import enable_func_timer
|
||||
@@ -475,6 +476,32 @@ async def sagemaker_chat_completions(raw_request: Request):
|
||||
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
|
||||
|
||||
|
||||
## Vertex AI API
@app.post(os.environ.get("AIP_PREDICT_ROUTE", "/vertex_generate"))
async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
    """Vertex AI compatible prediction route wrapping /generate.

    Accepts the Vertex AI Online Prediction request body
    ``{"instances": [...], "parameters": {...}}`` and responds with
    ``{"predictions": [...]}``, one prediction per instance.
    """
    if not vertex_req.instances:
        # Nothing to generate; mirror an empty request with an empty response.
        return []
    # The first instance decides which input field the whole batch uses.
    # Compare against None rather than truthiness so present-but-falsy
    # inputs (e.g. {"text": ""}) are still forwarded instead of dropped.
    inputs = {}
    for input_key in ("text", "input_ids", "input_embeds"):
        if vertex_req.instances[0].get(input_key) is not None:
            inputs[input_key] = [
                instance.get(input_key) for instance in vertex_req.instances
            ]
            break
    # Optional multimodal payload; None when no instance carries image data.
    image_data = [
        instance.get("image_data")
        for instance in vertex_req.instances
        if instance.get("image_data") is not None
    ] or None
    req = GenerateReqInput(
        **inputs,
        image_data=image_data,
        # Remaining Vertex "parameters" (e.g. sampling_params) are forwarded
        # verbatim as GenerateReqInput keyword arguments.
        **(vertex_req.parameters or {}),
    )
    ret = await generate_request(req, raw_request)
    return ORJSONResponse({"predictions": ret})
|
||||
|
||||
def _create_error_response(e):
|
||||
return ORJSONResponse(
|
||||
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
|
||||
|
||||
@@ -568,3 +568,9 @@ class FunctionCallReqInput:
|
||||
tool_call_parser: Optional[str] = (
|
||||
None # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
|
||||
)
|
||||
|
||||
|
||||
@dataclass
class VertexGenerateReqInput:
    """Request body of the Vertex AI compatible prediction route.

    Matches the Google Cloud Vertex AI Online Prediction request format:
    a list of instances plus optional shared parameters.
    """

    # One dict per prompt; each provides "text", "input_ids", or
    # "input_embeds", and may optionally carry "image_data".
    instances: List[dict]
    # Extra keyword arguments forwarded to GenerateReqInput
    # (e.g. {"sampling_params": {...}}); None means server defaults.
    parameters: Optional[dict] = None
||||
|
||||
@@ -50,6 +50,7 @@ suites = {
|
||||
"test_hidden_states.py",
|
||||
"test_update_weights_from_disk.py",
|
||||
"test_update_weights_from_tensor.py",
|
||||
"test_vertex_endpoint.py",
|
||||
"test_vision_chunked_prefill.py",
|
||||
"test_vision_llm.py",
|
||||
"test_vision_openai_server.py",
|
||||
|
||||
52
test/srt/test_vertex_endpoint.py
Normal file
52
test/srt/test_vertex_endpoint.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""
|
||||
python3 -m unittest test_vertex_endpoint.TestVertexEndpoint.test_vertex_generate
|
||||
"""
|
||||
|
||||
import unittest
|
||||
|
||||
import requests
|
||||
|
||||
from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
popen_launch_server,
|
||||
)
|
||||
|
||||
|
||||
class TestVertexEndpoint(unittest.TestCase):
    """End-to-end test of the Vertex AI compatible /vertex_generate route."""

    @classmethod
    def setUpClass(cls):
        # Launch one shared server process for all tests in this class.
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def run_generate(self, parameters):
        """POST two prompts with the given parameters and validate the reply.

        Returns the decoded JSON response for optional further inspection.
        """
        data = {
            "instances": [
                {"text": "The capital of France is"},
                {"text": "The capital of China is"},
            ],
            "parameters": parameters,
        }
        response = requests.post(self.base_url + "/vertex_generate", json=data)
        # Check the HTTP status first so a server-side failure is reported
        # with its body instead of a JSON decoding / KeyError traceback.
        self.assertEqual(response.status_code, 200, response.text)
        response_json = response.json()
        # The route must return exactly one prediction per instance.
        self.assertEqual(len(response_json["predictions"]), len(data["instances"]))
        return response_json

    def test_vertex_generate(self):
        # Cover both default parameters and explicit sampling params;
        # subTest identifies which parameter set failed.
        for parameters in [None, {"sampling_params": {"max_new_tokens": 4}}]:
            with self.subTest(parameters=parameters):
                self.run_generate(parameters)
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user