Add support for VertexAI safety settings (#624)
This commit is contained in:
@@ -1,8 +1,6 @@
|
|||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
from typing import List, Optional, Union
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
|
|
||||||
from sglang.backend.base_backend import BaseBackend
|
from sglang.backend.base_backend import BaseBackend
|
||||||
from sglang.lang.chat_template import get_chat_template
|
from sglang.lang.chat_template import get_chat_template
|
||||||
@@ -16,12 +14,15 @@ try:
|
|||||||
GenerativeModel,
|
GenerativeModel,
|
||||||
Image,
|
Image,
|
||||||
)
|
)
|
||||||
|
from vertexai.generative_models._generative_models import SafetySettingsType
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
GenerativeModel = e
|
GenerativeModel = e
|
||||||
|
|
||||||
|
|
||||||
class VertexAI(BaseBackend):
|
class VertexAI(BaseBackend):
|
||||||
def __init__(self, model_name):
|
def __init__(
|
||||||
|
self, model_name, safety_settings: Optional[SafetySettingsType] = None
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
if isinstance(GenerativeModel, Exception):
|
if isinstance(GenerativeModel, Exception):
|
||||||
@@ -33,6 +34,7 @@ class VertexAI(BaseBackend):
|
|||||||
|
|
||||||
self.model_name = model_name
|
self.model_name = model_name
|
||||||
self.chat_template = get_chat_template("default")
|
self.chat_template = get_chat_template("default")
|
||||||
|
self.safety_settings = safety_settings
|
||||||
|
|
||||||
def get_chat_template(self):
|
def get_chat_template(self):
|
||||||
return self.chat_template
|
return self.chat_template
|
||||||
@@ -54,6 +56,7 @@ class VertexAI(BaseBackend):
|
|||||||
ret = GenerativeModel(self.model_name).generate_content(
|
ret = GenerativeModel(self.model_name).generate_content(
|
||||||
prompt,
|
prompt,
|
||||||
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
|
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
|
||||||
|
safety_settings=self.safety_settings,
|
||||||
)
|
)
|
||||||
|
|
||||||
comp = ret.text
|
comp = ret.text
|
||||||
@@ -78,6 +81,7 @@ class VertexAI(BaseBackend):
|
|||||||
prompt,
|
prompt,
|
||||||
stream=True,
|
stream=True,
|
||||||
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
|
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
|
||||||
|
safety_settings=self.safety_settings,
|
||||||
)
|
)
|
||||||
for ret in generator:
|
for ret in generator:
|
||||||
yield ret.text, {}
|
yield ret.text, {}
|
||||||
|
|||||||
Reference in New Issue
Block a user