From c193002297d18efeacbc0887ec1c3a4c7b2c039e Mon Sep 17 00:00:00 2001 From: Aidan Cooper <30752032+AidanCooper@users.noreply.github.com> Date: Tue, 16 Jul 2024 19:54:42 +0100 Subject: [PATCH] Add support for VertexAI safety settings (#624) --- python/sglang/backend/vertexai.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python/sglang/backend/vertexai.py b/python/sglang/backend/vertexai.py index f32fca2f4..07827b76d 100644 --- a/python/sglang/backend/vertexai.py +++ b/python/sglang/backend/vertexai.py @@ -1,8 +1,6 @@ import os import warnings -from typing import List, Optional, Union - -import numpy as np +from typing import Optional from sglang.backend.base_backend import BaseBackend from sglang.lang.chat_template import get_chat_template @@ -16,12 +14,15 @@ try: GenerativeModel, Image, ) + from vertexai.generative_models._generative_models import SafetySettingsType except ImportError as e: GenerativeModel = e class VertexAI(BaseBackend): - def __init__(self, model_name): + def __init__( + self, model_name, safety_settings: Optional[SafetySettingsType] = None + ): super().__init__() if isinstance(GenerativeModel, Exception): @@ -33,6 +34,7 @@ class VertexAI(BaseBackend): self.model_name = model_name self.chat_template = get_chat_template("default") + self.safety_settings = safety_settings def get_chat_template(self): return self.chat_template @@ -54,6 +56,7 @@ class VertexAI(BaseBackend): ret = GenerativeModel(self.model_name).generate_content( prompt, generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()), + safety_settings=self.safety_settings, ) comp = ret.text @@ -78,6 +81,7 @@ class VertexAI(BaseBackend): prompt, stream=True, generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()), + safety_settings=self.safety_settings, ) for ret in generator: yield ret.text, {}