diff --git a/python/sglang/srt/entrypoints/http_server.py b/python/sglang/srt/entrypoints/http_server.py
index 32a11c15c..5a97072de 100644
--- a/python/sglang/srt/entrypoints/http_server.py
+++ b/python/sglang/srt/entrypoints/http_server.py
@@ -19,6 +19,7 @@ This file implements HTTP APIs for the inference engine via fastapi.
 
 import asyncio
 import dataclasses
+import json
 import logging
 import multiprocessing as multiprocessing
 import os
@@ -259,6 +260,39 @@ async def generate_request(obj: GenerateReqInput, request: Request):
             return _create_error_response(e)
 
 
+@app.api_route("/generate_from_file", methods=["POST"])
+async def generate_from_file_request(file: UploadFile, request: Request):
+    """Handle a generate request, this is purely to work with input_embeds.
+
+    The uploaded file must contain UTF-8 encoded JSON; its decoded value is
+    passed to ``GenerateReqInput`` as ``input_embeds`` together with a fixed
+    set of sampling parameters.
+
+    Returns the first item yielded by the generate call, or the response
+    built by ``_create_error_response`` when a ``ValueError`` is raised.
+    """
+    content = await file.read()
+
+    try:
+        # json.JSONDecodeError and UnicodeDecodeError both subclass
+        # ValueError, so parsing inside the try block reports a malformed
+        # upload through the same error response as a rejected request.
+        input_embeds = json.loads(content.decode("utf-8"))
+        obj = GenerateReqInput(
+            input_embeds=input_embeds,
+            sampling_params={
+                "repetition_penalty": 1.2,
+                "temperature": 0.2,
+                "max_new_tokens": 512,
+            },
+        )
+        ret = await _global_state.generate_request(obj, request).__anext__()
+        return ret
+    except ValueError as e:
+        logger.error(f"Error: {e}")
+        return _create_error_response(e)
+
+
 @app.api_route("/encode", methods=["POST", "PUT"])
 async def encode_request(obj: EmbeddingReqInput, request: Request):
     """Handle an embedding request."""
diff --git a/test/srt/test_input_embeddings.py b/test/srt/test_input_embeddings.py
index 015aabe78..92b643fd3 100644
--- a/test/srt/test_input_embeddings.py
+++ b/test/srt/test_input_embeddings.py
@@ -1,4 +1,6 @@
 import json
+import os
+import tempfile
 import unittest
 
 import requests
@@ -38,7 +40,7 @@ class TestInputEmbeds(unittest.TestCase):
         return embeddings.squeeze().tolist()  # Convert tensor to a list for API use
 
     def send_request(self, payload):
-        """Send a POST request to the API and return the response."""
+        """Send a POST request to the /generate endpoint and return the response."""
         response = requests.post(
             self.base_url + "/generate",
             json=payload,
@@ -50,8 +52,22 @@ class TestInputEmbeds(unittest.TestCase):
             "error": f"Request failed with status {response.status_code}: {response.text}"
         }
 
+    def send_file_request(self, file_path):
+        """Send a POST request to the /generate_from_file endpoint with a file."""
+        with open(file_path, "rb") as f:
+            response = requests.post(
+                self.base_url + "/generate_from_file",
+                files={"file": f},
+                timeout=30,  # Set a reasonable timeout for the API request
+            )
+        if response.status_code == 200:
+            return response.json()
+        return {
+            "error": f"Request failed with status {response.status_code}: {response.text}"
+        }
+
     def test_text_based_response(self):
-        """Print API response using text-based input."""
+        """Test and print API responses using text-based input."""
         for text in self.texts:
             payload = {
                 "model": self.model,
@@ -64,7 +80,7 @@ class TestInputEmbeds(unittest.TestCase):
             )
 
     def test_embedding_based_response(self):
-        """Print API response using input embeddings."""
+        """Test and print API responses using input embeddings."""
         for text in self.texts:
             embeddings = self.generate_input_embeddings(text)
             payload = {
@@ -78,7 +94,7 @@ class TestInputEmbeds(unittest.TestCase):
             )
 
     def test_compare_text_vs_embedding(self):
-        """Print responses for both text-based and embedding-based inputs."""
+        """Test and compare responses for text-based and embedding-based inputs."""
         for text in self.texts:
             # Text-based payload
             text_payload = {
@@ -106,6 +122,25 @@ class TestInputEmbeds(unittest.TestCase):
             # This is flaky, so we skip this temporarily
             # self.assertEqual(text_response["text"], embed_response["text"])
 
+    def test_generate_from_file(self):
+        """Test the /generate_from_file endpoint using tokenized embeddings."""
+        for text in self.texts:
+            embeddings = self.generate_input_embeddings(text)
+            with tempfile.NamedTemporaryFile(
+                mode="w", suffix=".json", delete=False
+            ) as tmp_file:
+                json.dump(embeddings, tmp_file)
+                tmp_file_path = tmp_file.name
+
+            try:
+                response = self.send_file_request(tmp_file_path)
+                print(
+                    f"Text Input: {text}\nResponse from /generate_from_file: {json.dumps(response, indent=2)}\n{'-' * 80}"
+                )
+            finally:
+                # Ensure the temporary file is deleted
+                os.remove(tmp_file_path)
+
     @classmethod
     def tearDownClass(cls):
         kill_process_tree(cls.process.pid)