Add a test case for cached_tokens (#3145)
This commit is contained in:
@@ -18,7 +18,6 @@ suites = {
|
||||
"test_eagle_infer.py",
|
||||
"test_embedding_openai_server.py",
|
||||
"test_eval_accuracy_mini.py",
|
||||
"test_get_weights_by_name.py",
|
||||
"test_gguf.py",
|
||||
"test_input_embeddings.py",
|
||||
"test_json_constrained.py",
|
||||
|
||||
@@ -236,12 +236,5 @@ class TestEBNFConstrained(unittest.TestCase):
|
||||
)
|
||||
|
||||
|
||||
class TestJumpForward(TestEBNFConstrained):
    """Re-run the EBNF-constrained suite with jump-forward checking enabled.

    Inherits every test from ``TestEBNFConstrained``; only the class-level
    setup differs: overlap scheduling is disabled and the jump-forward flag
    is switched on before the server fixture is created.
    """

    @classmethod
    def setUpClass(cls):
        # Flag is independent of the fixture; set it up front for clarity.
        cls.check_jump_forward = True
        setup_class(cls, disable_overlap=True)
|
||||
|
||||
|
||||
# Script entry point: discover and run every test case in this module.
if __name__ == "__main__":
    unittest.main()
|
||||
|
||||
@@ -5,6 +5,7 @@ python3 -m unittest test_srt_endpoint.TestSRTEndpoint.test_logprob_with_chunked_
|
||||
|
||||
import json
|
||||
import random
|
||||
import time
|
||||
import unittest
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
@@ -317,12 +318,6 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
"""Test custom logit processor with a single request."""
|
||||
self.run_custom_logit_processor(target_token_id=5)
|
||||
|
||||
def test_custom_logit_processor_batch(self):
    """Test custom logit processor with a batch of requests."""
    # One concurrent request per distinct target token id (0..31).
    token_ids = [tid for tid in range(32)]
    with ThreadPoolExecutor(len(token_ids)) as pool:
        # Drain the iterator so every request completes (and any
        # exception raised inside a worker is re-raised here).
        for _ in pool.map(self.run_custom_logit_processor, token_ids):
            pass
|
||||
|
||||
def test_custom_logit_processor_batch_mixed(self):
|
||||
"""Test a batch of requests mixed of requests with and without custom logit processor."""
|
||||
target_token_ids = list(range(32)) + [None] * 16
|
||||
@@ -330,6 +325,31 @@ class TestSRTEndpoint(unittest.TestCase):
|
||||
with ThreadPoolExecutor(len(target_token_ids)) as executor:
|
||||
list(executor.map(self.run_custom_logit_processor, target_token_ids))
|
||||
|
||||
def test_cache_tokens(self):
    """Verify ``meta_info.cached_tokens`` reflects prefix-cache reuse.

    Runs the whole scenario twice; each round starts from an empty cache
    by calling ``/flush_cache``, so the expected counts below hold on
    both iterations.
    """
    for _ in range(2):
        # Brief pause so in-flight requests drain before the flush.
        time.sleep(1)
        response = requests.post(self.base_url + "/flush_cache")
        # Fix: use a unittest assertion instead of a bare `assert`
        # (bare asserts are stripped under `python -O` and give a
        # less informative failure message).
        self.assertEqual(response.status_code, 200)

        def send_and_check_cached_tokens(input_ids):
            """Generate one token from ``input_ids``; return the server's
            reported number of cached (prefix-reused) prompt tokens."""
            response = requests.post(
                self.base_url + "/generate",
                json={
                    "input_ids": list(input_ids),
                    "sampling_params": {
                        "max_new_tokens": 1,
                    },
                },
            )
            response_json = response.json()
            return response_json["meta_info"]["cached_tokens"]

        # Fresh cache: nothing to reuse.
        self.assertEqual(send_and_check_cached_tokens(range(0, 100)), 0)
        # Only the 100-token prefix from the previous request is cached.
        self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 100)
        # Identical prompt: all but the final token is reused.
        self.assertEqual(send_and_check_cached_tokens(range(0, 10000)), 9999)
        # Strict prefix of a cached prompt: all but the final token.
        self.assertEqual(send_and_check_cached_tokens(range(0, 1000)), 999)
        # Longer prompt: the full 10000-token cached prefix is reused.
        self.assertEqual(send_and_check_cached_tokens(range(0, 11000)), 10000)
|
||||
|
||||
def test_get_server_info(self):
|
||||
response = requests.get(self.base_url + "/get_server_info")
|
||||
response_json = response.json()
|
||||
|
||||
Reference in New Issue
Block a user