Add endpoint for file support, purely to speed up processing of input_embeds. (#2797)
This commit is contained in:
@@ -19,6 +19,7 @@ This file implements HTTP APIs for the inference engine via fastapi.
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import json
|
||||||
import logging
|
import logging
|
||||||
import multiprocessing as multiprocessing
|
import multiprocessing as multiprocessing
|
||||||
import os
|
import os
|
||||||
@@ -259,6 +260,29 @@ async def generate_request(obj: GenerateReqInput, request: Request):
|
|||||||
return _create_error_response(e)
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
|
@app.api_route("/generate_from_file", methods=["POST"])
|
||||||
|
async def generate_from_file_request(file: UploadFile, request: Request):
|
||||||
|
"""Handle a generate request, this is purely to work with input_embeds."""
|
||||||
|
content = await file.read()
|
||||||
|
input_embeds = json.loads(content.decode("utf-8"))
|
||||||
|
|
||||||
|
obj = GenerateReqInput(
|
||||||
|
input_embeds=input_embeds,
|
||||||
|
sampling_params={
|
||||||
|
"repetition_penalty": 1.2,
|
||||||
|
"temperature": 0.2,
|
||||||
|
"max_new_tokens": 512,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
ret = await _global_state.generate_request(obj, request).__anext__()
|
||||||
|
return ret
|
||||||
|
except ValueError as e:
|
||||||
|
logger.error(f"Error: {e}")
|
||||||
|
return _create_error_response(e)
|
||||||
|
|
||||||
|
|
||||||
@app.api_route("/encode", methods=["POST", "PUT"])
|
@app.api_route("/encode", methods=["POST", "PUT"])
|
||||||
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
async def encode_request(obj: EmbeddingReqInput, request: Request):
|
||||||
"""Handle an embedding request."""
|
"""Handle an embedding request."""
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
import json
|
import json
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
@@ -38,7 +40,7 @@ class TestInputEmbeds(unittest.TestCase):
|
|||||||
return embeddings.squeeze().tolist() # Convert tensor to a list for API use
|
return embeddings.squeeze().tolist() # Convert tensor to a list for API use
|
||||||
|
|
||||||
def send_request(self, payload):
|
def send_request(self, payload):
|
||||||
"""Send a POST request to the API and return the response."""
|
"""Send a POST request to the /generate endpoint and return the response."""
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
self.base_url + "/generate",
|
self.base_url + "/generate",
|
||||||
json=payload,
|
json=payload,
|
||||||
@@ -50,8 +52,22 @@ class TestInputEmbeds(unittest.TestCase):
|
|||||||
"error": f"Request failed with status {response.status_code}: {response.text}"
|
"error": f"Request failed with status {response.status_code}: {response.text}"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def send_file_request(self, file_path):
|
||||||
|
"""Send a POST request to the /generate_from_file endpoint with a file."""
|
||||||
|
with open(file_path, "rb") as f:
|
||||||
|
response = requests.post(
|
||||||
|
self.base_url + "/generate_from_file",
|
||||||
|
files={"file": f},
|
||||||
|
timeout=30, # Set a reasonable timeout for the API request
|
||||||
|
)
|
||||||
|
if response.status_code == 200:
|
||||||
|
return response.json()
|
||||||
|
return {
|
||||||
|
"error": f"Request failed with status {response.status_code}: {response.text}"
|
||||||
|
}
|
||||||
|
|
||||||
def test_text_based_response(self):
|
def test_text_based_response(self):
|
||||||
"""Print API response using text-based input."""
|
"""Test and print API responses using text-based input."""
|
||||||
for text in self.texts:
|
for text in self.texts:
|
||||||
payload = {
|
payload = {
|
||||||
"model": self.model,
|
"model": self.model,
|
||||||
@@ -64,7 +80,7 @@ class TestInputEmbeds(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_embedding_based_response(self):
|
def test_embedding_based_response(self):
|
||||||
"""Print API response using input embeddings."""
|
"""Test and print API responses using input embeddings."""
|
||||||
for text in self.texts:
|
for text in self.texts:
|
||||||
embeddings = self.generate_input_embeddings(text)
|
embeddings = self.generate_input_embeddings(text)
|
||||||
payload = {
|
payload = {
|
||||||
@@ -78,7 +94,7 @@ class TestInputEmbeds(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def test_compare_text_vs_embedding(self):
|
def test_compare_text_vs_embedding(self):
|
||||||
"""Print responses for both text-based and embedding-based inputs."""
|
"""Test and compare responses for text-based and embedding-based inputs."""
|
||||||
for text in self.texts:
|
for text in self.texts:
|
||||||
# Text-based payload
|
# Text-based payload
|
||||||
text_payload = {
|
text_payload = {
|
||||||
@@ -106,6 +122,25 @@ class TestInputEmbeds(unittest.TestCase):
|
|||||||
# This is flaky, so we skip this temporarily
|
# This is flaky, so we skip this temporarily
|
||||||
# self.assertEqual(text_response["text"], embed_response["text"])
|
# self.assertEqual(text_response["text"], embed_response["text"])
|
||||||
|
|
||||||
|
def test_generate_from_file(self):
|
||||||
|
"""Test the /generate_from_file endpoint using tokenized embeddings."""
|
||||||
|
for text in self.texts:
|
||||||
|
embeddings = self.generate_input_embeddings(text)
|
||||||
|
with tempfile.NamedTemporaryFile(
|
||||||
|
mode="w", suffix=".json", delete=False
|
||||||
|
) as tmp_file:
|
||||||
|
json.dump(embeddings, tmp_file)
|
||||||
|
tmp_file_path = tmp_file.name
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = self.send_file_request(tmp_file_path)
|
||||||
|
print(
|
||||||
|
f"Text Input: {text}\nResponse from /generate_from_file: {json.dumps(response, indent=2)}\n{'-' * 80}"
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
# Ensure the temporary file is deleted
|
||||||
|
os.remove(tmp_file_path)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def tearDownClass(cls):
|
def tearDownClass(cls):
|
||||||
kill_process_tree(cls.process.pid)
|
kill_process_tree(cls.process.pid)
|
||||||
|
|||||||
Reference in New Issue
Block a user