diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py index 147ab4131..1064d2dbf 100644 --- a/python/sglang/srt/entrypoints/http_server.py +++ b/python/sglang/srt/entrypoints/http_server.py @@ -675,6 +675,8 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque **(vertex_req.parameters or {}), ) ret = await generate_request(req, raw_request) + if isinstance(ret, Response): + return ret return ORJSONResponse({"predictions": ret}) diff --git a/test/srt/test_vertex_endpoint.py b/test/srt/test_vertex_endpoint.py index a899d6251..42e48cb1b 100644 --- a/test/srt/test_vertex_endpoint.py +++ b/test/srt/test_vertex_endpoint.py @@ -3,6 +3,7 @@ python3 -m unittest test_vertex_endpoint.TestVertexEndpoint.test_vertex_generate """ import unittest +from http import HTTPStatus import requests @@ -49,6 +50,15 @@ class TestVertexEndpoint(CustomTestCase): for parameters in [None, {"sampling_params": {"max_new_tokens": 4}}]: self.run_generate(parameters) + def test_vertex_generate_fail(self): + data = { + "instances": [ + {"prompt": "The capital of France is"}, + ], + } + response = requests.post(self.base_url + "/vertex_generate", json=data) + assert response.status_code == HTTPStatus.BAD_REQUEST + if __name__ == "__main__": unittest.main()