# File: sglang/test/srt/test_input_embeddings.py
# Snapshot: 2024-11-28 00:22:39 -08:00 (115 lines, 4.2 KiB, Python)

import json
import unittest
import requests
from transformers import AutoModelForCausalLM, AutoTokenizer
from sglang.srt.utils import kill_process_tree
from sglang.test.test_utils import (
DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
popen_launch_server,
)
class TestInputEmbeds(unittest.TestCase):
    """Check that ``input_embeds`` requests behave like ``text`` requests.

    One server is launched for the whole class. Reference input embeddings are
    computed locally with the Hugging Face copy of the same model and sent to
    the server's ``/generate`` endpoint alongside the equivalent text prompt.
    """

    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_SMALL_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model)
        # Local HF copy of the model; only its input-embedding layer is used.
        cls.ref_model = AutoModelForCausalLM.from_pretrained(cls.model)
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            # NOTE(review): radix (prefix) cache is disabled, presumably because
            # embedding-only requests carry no token ids to key it on — confirm.
            other_args=["--disable-radix"],
        )
        cls.texts = [
            "The capital of France is",
            "What is the best time of year to visit Japan for cherry blossoms?",
        ]

    def generate_input_embeddings(self, text):
        """Return the reference model's input embeddings for ``text``.

        The text is tokenized as a batch of one, passed through the reference
        model's input-embedding layer, and converted to nested Python lists so
        it can be sent as JSON.

        Returns:
            A list with one entry per token, each entry a list of floats.
        """
        input_ids = self.tokenizer(text, return_tensors="pt")["input_ids"]
        embeddings = self.ref_model.get_input_embeddings()(input_ids)
        # Drop only the batch dimension. A bare ``squeeze()`` would also
        # collapse the sequence dimension for a single-token prompt and send
        # a flat vector instead of a one-row list of vectors. ``detach()``
        # drops the autograd link before converting to plain Python values.
        return embeddings.squeeze(0).detach().tolist()

    def send_request(self, payload):
        """POST ``payload`` to the server's ``/generate`` endpoint.

        Returns:
            The decoded JSON body on HTTP 200. Otherwise, a dict with a single
            ``"error"`` key describing the status code and response body.
        """
        response = requests.post(
            self.base_url + "/generate",
            json=payload,
            timeout=30,  # Set a reasonable timeout for the API request
        )
        if response.status_code == 200:
            return response.json()
        return {
            "error": f"Request failed with status {response.status_code}: {response.text}"
        }

    def _make_payload(self, **inputs):
        """Build a greedy ``/generate`` payload around the given input field(s)."""
        return {
            "model": self.model,
            **inputs,
            "sampling_params": {"temperature": 0, "max_new_tokens": 50},
        }

    def _assert_ok(self, response):
        """Fail with the server's message if ``send_request`` reported an error."""
        self.assertNotIn("error", response, response.get("error"))
        self.assertIn("text", response)

    def test_text_based_response(self):
        """Generate from plain text input and check that each request succeeds."""
        for text in self.texts:
            with self.subTest(text=text):
                response = self.send_request(self._make_payload(text=text))
                print(
                    f"Text Input: {text}\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
                )
                self._assert_ok(response)

    def test_embedding_based_response(self):
        """Generate from input embeddings and check that each request succeeds."""
        for text in self.texts:
            with self.subTest(text=text):
                embeddings = self.generate_input_embeddings(text)
                response = self.send_request(
                    self._make_payload(input_embeds=embeddings)
                )
                print(
                    f"Embeddings Input (for text '{text}'):\nResponse: {json.dumps(response, indent=2)}\n{'-' * 80}"
                )
                self._assert_ok(response)

    def test_compare_text_vs_embedding(self):
        """Greedy output from text and from its embeddings must be identical."""
        for text in self.texts:
            with self.subTest(text=text):
                text_payload = self._make_payload(text=text)
                embeddings = self.generate_input_embeddings(text)
                embed_payload = self._make_payload(input_embeds=embeddings)

                text_response = self.send_request(text_payload)
                embed_response = self.send_request(embed_payload)

                print(
                    f"Text Input: {text}\nText-Based Response: {json.dumps(text_response, indent=2)}\n"
                )
                print(
                    f"Embeddings Input (for text '{text}'):\nEmbedding-Based Response: {json.dumps(embed_response, indent=2)}\n{'-' * 80}"
                )
                # Check both requests succeeded first, so a server error is
                # reported as such rather than as a KeyError on "text".
                self._assert_ok(text_response)
                self._assert_ok(embed_response)
                self.assertEqual(text_response["text"], embed_response["text"])

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)
if __name__ == "__main__":
    # Running this file directly discovers and runs the TestCase classes in it;
    # verbosity=1 is unittest's default, spelled out for clarity.
    unittest.main(verbosity=1)