Add a test case for cached_tokens (#3145)

This commit is contained in:
Lianmin Zheng
2025-01-26 01:39:28 -08:00
committed by GitHub
parent f8b28e461a
commit d1a0863251
6 changed files with 74 additions and 63 deletions

View File

@@ -18,7 +18,6 @@ suites = {
"test_eagle_infer.py",
"test_embedding_openai_server.py",
"test_eval_accuracy_mini.py",
"test_get_weights_by_name.py",
"test_gguf.py",
"test_input_embeddings.py",
"test_json_constrained.py",

View File

@@ -236,12 +236,5 @@ class TestEBNFConstrained(unittest.TestCase):
)
class TestJumpForward(TestEBNFConstrained):
@classmethod
def setUpClass(cls):
    """Start the shared test server once for all tests in this subclass."""
    # NOTE(review): `setup_class` is a project helper defined elsewhere in the
    # repo; presumably `disable_overlap=True` is required because jump-forward
    # decoding conflicts with the overlap scheduler — confirm against the helper.
    setup_class(cls, disable_overlap=True)
    # Flag read by the inherited TestEBNFConstrained tests to additionally
    # verify jump-forward behavior.
    cls.check_jump_forward = True
if __name__ == "__main__":
unittest.main()

View File

@@ -5,6 +5,7 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
import json
import random
import time
import unittest
from concurrent.futures import ThreadPoolExecutor
from typing import Optional
@@ -317,12 +318,6 @@ class TestSRTEndpoint(unittest.TestCase):
"""Test custom logit processor with a single request."""
self.run_custom_logit_processor(target_token_id=5)
def test_custom_logit_processor_batch(self):
    """Test custom logit processor with a batch of requests."""
    token_ids = list(range(32))
    # Fan the requests out concurrently, one worker per target token,
    # and surface any per-request exception by resolving every future.
    with ThreadPoolExecutor(len(token_ids)) as pool:
        futures = [
            pool.submit(self.run_custom_logit_processor, tid)
            for tid in token_ids
        ]
        for fut in futures:
            fut.result()
def test_custom_logit_processor_batch_mixed(self):
"""Test a batch of requests mixed of requests with and without custom logit processor."""
target_token_ids = list(range(32)) + [None] * 16
@@ -330,6 +325,31 @@ class TestSRTEndpoint(unittest.TestCase):
with ThreadPoolExecutor(len(target_token_ids)) as executor:
list(executor.map(self.run_custom_logit_processor, target_token_ids))
def test_cache_tokens(self):
    """Check `meta_info["cached_tokens"]` reported by the /generate endpoint.

    NOTE(review): indentation was lost in this diff rendering; the
    reconstruction below assumes the flush and all the checks sit inside the
    `for` loop, so the whole scenario runs twice from a cold cache — confirm
    against the repository.
    """
    for _ in range(2):
        # Reset the server-side prefix cache so cached-token counts start
        # from zero; the sleep presumably lets in-flight work drain first.
        time.sleep(1)
        response = requests.post(self.base_url + "/flush_cache")
        assert response.status_code == 200

        def send_and_check_cached_tokens(input_ids):
            # Issue a minimal one-token generation for `input_ids` and
            # return the cached-token count the server reports for it.
            response = requests.post(
                self.base_url + "/generate",
                json={
                    "input_ids": list(input_ids),
                    "sampling_params": {
                        "max_new_tokens": 1,
                    },
                },
            )
            response_json = response.json()
            return response_json["meta_info"]["cached_tokens"]

        # Cold cache: nothing is reusable.
        self.assertEqual(send_and_check_cached_tokens(range(0, 100)), 0)
        # Presumably only the previously seen 100-token prefix is cached.
        self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 100)
        # Exact repeat: all but the final token appear to hit the cache.
        self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 9999)
        self.assertEqual(send_and_check_cached_tokens(range(0, 1000)), 999)
        self.assertEqual(send_and_check_cached_tokens(range(0, 11000)), 10000)
def test_get_server_info(self):
response = requests.get(self.base_url + "/get_server_info")
response_json = response.json()