Add support for VertexAI safety settings (#624)
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
import os
|
||||
import warnings
|
||||
from typing import List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
from typing import Optional
|
||||
|
||||
from sglang.backend.base_backend import BaseBackend
|
||||
from sglang.lang.chat_template import get_chat_template
|
||||
@@ -16,12 +14,15 @@ try:
|
||||
GenerativeModel,
|
||||
Image,
|
||||
)
|
||||
from vertexai.generative_models._generative_models import SafetySettingsType
|
||||
except ImportError as e:
|
||||
GenerativeModel = e
|
||||
|
||||
|
||||
class VertexAI(BaseBackend):
|
||||
def __init__(self, model_name):
|
||||
def __init__(
|
||||
self, model_name, safety_settings: Optional[SafetySettingsType] = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if isinstance(GenerativeModel, Exception):
|
||||
@@ -33,6 +34,7 @@ class VertexAI(BaseBackend):
|
||||
|
||||
self.model_name = model_name
|
||||
self.chat_template = get_chat_template("default")
|
||||
self.safety_settings = safety_settings
|
||||
|
||||
def get_chat_template(self):
|
||||
return self.chat_template
|
||||
@@ -54,6 +56,7 @@ class VertexAI(BaseBackend):
|
||||
ret = GenerativeModel(self.model_name).generate_content(
|
||||
prompt,
|
||||
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
|
||||
safety_settings=self.safety_settings,
|
||||
)
|
||||
|
||||
comp = ret.text
|
||||
@@ -78,6 +81,7 @@ class VertexAI(BaseBackend):
|
||||
prompt,
|
||||
stream=True,
|
||||
generation_config=GenerationConfig(**sampling_params.to_vertexai_kwargs()),
|
||||
safety_settings=self.safety_settings,
|
||||
)
|
||||
for ret in generator:
|
||||
yield ret.text, {}
|
||||
|
||||
Reference in New Issue
Block a user