Properly return error response in vertex_generate HTTP endpoint (#5956)

This commit is contained in:
KCFindstr
2025-05-01 11:48:58 -07:00
committed by GitHub
parent 6fc175968c
commit d33955d28a
2 changed files with 12 additions and 0 deletions

View File

@@ -675,6 +675,8 @@ async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Reque
**(vertex_req.parameters or {}),
)
ret = await generate_request(req, raw_request)
if isinstance(ret, Response):
return ret
return ORJSONResponse({"predictions": ret})

View File

@@ -3,6 +3,7 @@ python3 -m unittest test_vertex_endpoint.TestVertexEndpoint.test_vertex_generate
"""
import unittest
from http import HTTPStatus
import requests
@@ -49,6 +50,15 @@ class TestVertexEndpoint(CustomTestCase):
for parameters in [None, {"sampling_params": {"max_new_tokens": 4}}]:
self.run_generate(parameters)
def test_vertex_generate_fail(self):
data = {
"instances": [
{"prompt": "The capital of France is"},
],
}
response = requests.post(self.base_url + "/vertex_generate", json=data)
assert response.status_code == HTTPStatus.BAD_REQUEST
if __name__ == "__main__":
unittest.main()