[feat] Add Vertex AI compatible prediction route for /generate (#3866)

This commit is contained in:
KCFindstr
2025-02-27 19:42:15 -08:00
committed by GitHub
parent d38878523d
commit bc20e93f2d
5 changed files with 152 additions and 0 deletions

View File

@@ -53,6 +53,7 @@ from sglang.srt.managers.io_struct import (
ResumeMemoryOccupationReqInput,
UpdateWeightFromDiskReqInput,
UpdateWeightsFromDistributedReqInput,
VertexGenerateReqInput,
)
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.metrics.func_timer import enable_func_timer
@@ -475,6 +476,32 @@ async def sagemaker_chat_completions(raw_request: Request):
return await v1_chat_completions(_global_state.tokenizer_manager, raw_request)
## Vertex AI API
@app.post(os.environ.get("AIP_PREDICT_ROUTE", "/vertex_generate"))
async def vertex_generate(vertex_req: VertexGenerateReqInput, raw_request: Request):
if not vertex_req.instances:
return []
inputs = {}
for input_key in ("text", "input_ids", "input_embeds"):
if vertex_req.instances[0].get(input_key):
inputs[input_key] = [
instance.get(input_key) for instance in vertex_req.instances
]
break
image_data = [
instance.get("image_data")
for instance in vertex_req.instances
if instance.get("image_data") is not None
] or None
req = GenerateReqInput(
**inputs,
image_data=image_data,
**(vertex_req.parameters or {}),
)
ret = await generate_request(req, raw_request)
return ORJSONResponse({"predictions": ret})
def _create_error_response(e):
return ORJSONResponse(
{"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST

View File

@@ -568,3 +568,9 @@ class FunctionCallReqInput:
tool_call_parser: Optional[str] = (
None # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
)
@dataclass
class VertexGenerateReqInput:
instances: List[dict]
parameters: Optional[dict] = None