Make API Key OpenAI-compatible (#917)

This commit is contained in:
Ying Sheng
2024-08-04 13:35:44 -07:00
committed by GitHub
parent afd411d09f
commit 0d4f3a9fcd
7 changed files with 115 additions and 125 deletions

View File

@@ -13,7 +13,10 @@ class TestOpenAIServer(unittest.TestCase):
def setUpClass(cls):
cls.model = MODEL_NAME_FOR_TEST
cls.base_url = f"http://localhost:30000"
cls.process = popen_launch_server(cls.model, cls.base_url, timeout=300)
cls.api_key = "sk-123456"
cls.process = popen_launch_server(
cls.model, cls.base_url, timeout=300, api_key=cls.api_key
)
cls.base_url += "/v1"
@classmethod
@@ -21,7 +24,7 @@ class TestOpenAIServer(unittest.TestCase):
kill_child_process(cls.process.pid)
def run_completion(self, echo, logprobs, use_list_input):
client = openai.Client(api_key="EMPTY", base_url=self.base_url)
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
prompt = "The capital of France is"
if use_list_input:
@@ -63,7 +66,7 @@ class TestOpenAIServer(unittest.TestCase):
assert response.usage.total_tokens > 0
def run_completion_stream(self, echo, logprobs):
client = openai.Client(api_key="EMPTY", base_url=self.base_url)
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
prompt = "The capital of France is"
generator = client.completions.create(
model=self.model,
@@ -102,7 +105,7 @@ class TestOpenAIServer(unittest.TestCase):
assert response.usage.total_tokens > 0
def run_chat_completion(self, logprobs):
client = openai.Client(api_key="EMPTY", base_url=self.base_url)
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
response = client.chat.completions.create(
model=self.model,
messages=[
@@ -135,7 +138,7 @@ class TestOpenAIServer(unittest.TestCase):
assert response.usage.total_tokens > 0
def run_chat_completion_stream(self, logprobs):
client = openai.Client(api_key="EMPTY", base_url=self.base_url)
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
generator = client.chat.completions.create(
model=self.model,
messages=[
@@ -186,7 +189,7 @@ class TestOpenAIServer(unittest.TestCase):
self.run_chat_completion_stream(logprobs)
def test_regex(self):
client = openai.Client(api_key="EMPTY", base_url=self.base_url)
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
regex = (
r"""\{\n"""