Add endpoint for file support, purely to speed up processing of input_embeds. (#2797)

This commit is contained in:
Rin Intachuen
2025-03-17 08:30:37 +07:00
committed by GitHub
parent 48efec7b05
commit d1112d8548
2 changed files with 63 additions and 4 deletions

View File

@@ -1,4 +1,6 @@
import json
import os
import tempfile
import unittest
import requests
@@ -38,7 +40,7 @@ class TestInputEmbeds(unittest.TestCase):
return embeddings.squeeze().tolist() # Convert tensor to a list for API use
def send_request(self, payload):
"""Send a POST request to the API and return the response."""
"""Send a POST request to the /generate endpoint and return the response."""
response = requests.post(
self.base_url + "/generate",
json=payload,
@@ -50,8 +52,22 @@ class TestInputEmbeds(unittest.TestCase):
"error": f"Request failed with status {response.status_code}: {response.text}"
}
def send_file_request(self, file_path):
    """Upload the file at ``file_path`` to the /generate_from_file endpoint.

    The file is opened in binary mode and posted as the multipart form
    field named ``file``.

    Returns:
        The decoded JSON body when the server answers with status 200;
        otherwise a dict whose ``"error"`` entry holds the status code and
        the raw response text.
    """
    with open(file_path, "rb") as upload:
        endpoint = self.base_url + "/generate_from_file"
        # Bound the wait so an unresponsive server cannot hang the test run.
        response = requests.post(endpoint, files={"file": upload}, timeout=30)

    if response.status_code != 200:
        return {
            "error": f"Request failed with status {response.status_code}: {response.text}"
        }
    return response.json()
def test_text_based_response(self):
"""Print API response using text-based input."""
"""Test and print API responses using text-based input."""
for text in self.texts:
payload = {
"model": self.model,
@@ -64,7 +80,7 @@ class TestInputEmbeds(unittest.TestCase):
)
def test_embedding_based_response(self):
"""Print API response using input embeddings."""
"""Test and print API responses using input embeddings."""
for text in self.texts:
embeddings = self.generate_input_embeddings(text)
payload = {
@@ -78,7 +94,7 @@ class TestInputEmbeds(unittest.TestCase):
)
def test_compare_text_vs_embedding(self):
"""Print responses for both text-based and embedding-based inputs."""
"""Test and compare responses for text-based and embedding-based inputs."""
for text in self.texts:
# Text-based payload
text_payload = {
@@ -106,6 +122,25 @@ class TestInputEmbeds(unittest.TestCase):
# This is flaky, so we skip this temporarily
# self.assertEqual(text_response["text"], embed_response["text"])
def test_generate_from_file(self):
    """Test the /generate_from_file endpoint using tokenized embeddings.

    For every prompt in ``self.texts`` the embeddings returned by
    ``generate_input_embeddings`` are written as JSON to a temporary file,
    the file is uploaded through ``send_file_request`` and the response is
    printed. The temporary file is removed afterwards, including when
    writing it or sending the request fails.
    """
    for text in self.texts:
        embeddings = self.generate_input_embeddings(text)
        # delete=False keeps the file on disk after it is closed so that it
        # can be reopened by path for the upload; removal is done by hand.
        tmp_file = tempfile.NamedTemporaryFile(
            mode="w", suffix=".json", delete=False
        )
        tmp_file_path = tmp_file.name
        try:
            # Closing the handle flushes the JSON before the file is re-read.
            with tmp_file:
                json.dump(embeddings, tmp_file)
            response = self.send_file_request(tmp_file_path)
            print(
                f"Text Input: {text}\nResponse from /generate_from_file: {json.dumps(response, indent=2)}\n{'-' * 80}"
            )
        finally:
            # Ensure the temporary file is deleted even if the dump failed.
            os.remove(tmp_file_path)
@classmethod
def tearDownClass(cls):
kill_process_tree(cls.process.pid)