Fix eagle radix cache (#10846)
This commit is contained in:
@@ -9,6 +9,8 @@ from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3,
|
||||
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
|
||||
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
@@ -35,6 +37,11 @@ class TestEAGLEEngine(CustomTestCase):
|
||||
}
|
||||
NUM_CONFIGS = 2
|
||||
|
||||
THRESHOLDS = {
|
||||
"batch_avg_accept_len": 1.9,
|
||||
"accept_len": 3.6,
|
||||
}
|
||||
|
||||
def setUp(self):
|
||||
self.prompt = "Today is a sunny day and I like"
|
||||
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
||||
@@ -63,6 +70,7 @@ class TestEAGLEEngine(CustomTestCase):
|
||||
self._test_eos_token(engine)
|
||||
self._test_acc_length(engine)
|
||||
finally:
|
||||
engine.flush_cache() # check engine alive
|
||||
engine.shutdown()
|
||||
print("=" * 100)
|
||||
|
||||
@@ -92,7 +100,9 @@ class TestEAGLEEngine(CustomTestCase):
|
||||
"avg_spec_accept_length"
|
||||
]
|
||||
print(f"{avg_spec_accept_length=}")
|
||||
self.assertGreater(avg_spec_accept_length, 1.9)
|
||||
self.assertGreater(
|
||||
avg_spec_accept_length, self.THRESHOLDS["batch_avg_accept_len"]
|
||||
)
|
||||
|
||||
def _test_eos_token(self, engine):
|
||||
prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
|
||||
@@ -131,10 +141,7 @@ class TestEAGLEEngine(CustomTestCase):
|
||||
)
|
||||
print(f"{acc_length=:.4f}, {speed=}")
|
||||
|
||||
if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
|
||||
self.assertGreater(acc_length, 3.6)
|
||||
else:
|
||||
self.assertGreater(acc_length, 2.5)
|
||||
self.assertGreater(acc_length, self.THRESHOLDS["accept_len"])
|
||||
|
||||
|
||||
class TestEAGLEEngineTokenMap(TestEAGLEEngine):
|
||||
@@ -151,12 +158,16 @@ class TestEAGLEEngineTokenMap(TestEAGLEEngine):
|
||||
"dtype": "float16",
|
||||
}
|
||||
NUM_CONFIGS = 1
|
||||
THRESHOLDS = {
|
||||
"batch_avg_accept_len": 1.9,
|
||||
"accept_len": 2.5,
|
||||
}
|
||||
|
||||
|
||||
class TestEAGLE3Engine(TestEAGLEEngine):
|
||||
BASE_CONFIG = {
|
||||
"model_path": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"speculative_draft_model_path": "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
|
||||
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3,
|
||||
"speculative_draft_model_path": DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
|
||||
"speculative_algorithm": "EAGLE3",
|
||||
"speculative_num_steps": 5,
|
||||
"speculative_eagle_topk": 16,
|
||||
@@ -166,6 +177,72 @@ class TestEAGLE3Engine(TestEAGLEEngine):
|
||||
"dtype": "float16",
|
||||
}
|
||||
NUM_CONFIGS = 1
|
||||
THRESHOLDS = {
|
||||
"batch_avg_accept_len": 1.75,
|
||||
"accept_len": 3.1,
|
||||
}
|
||||
|
||||
|
||||
class TestEAGLERadixCache(CustomTestCase):
|
||||
BASE_CONFIG = {
|
||||
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3,
|
||||
"speculative_draft_model_path": DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
|
||||
"speculative_algorithm": "EAGLE3",
|
||||
"speculative_num_steps": 2,
|
||||
"speculative_eagle_topk": 1,
|
||||
"speculative_num_draft_tokens": 3,
|
||||
"mem_fraction_static": 0.7,
|
||||
"cuda_graph_max_bs": 5,
|
||||
"dtype": "float16",
|
||||
}
|
||||
|
||||
def test_correctness(self):
|
||||
configs = [
|
||||
# Basic config
|
||||
self.BASE_CONFIG,
|
||||
# Chunked prefill
|
||||
{**self.BASE_CONFIG, "chunked_prefill_size": 64},
|
||||
# Chunked prefill & Page Size > 1
|
||||
{**self.BASE_CONFIG, "chunked_prefill_size": 64, "page_size": 4},
|
||||
]
|
||||
|
||||
for i, config in enumerate(configs):
|
||||
with self.subTest(i=i):
|
||||
print(f"{config=}")
|
||||
engine = sgl.Engine(**config, log_level="info", decode_log_interval=10)
|
||||
try:
|
||||
self._test_acc_length(engine)
|
||||
finally:
|
||||
engine.shutdown()
|
||||
print("=" * 100)
|
||||
|
||||
def _test_acc_length(self, engine):
|
||||
warmup_prompt = [
|
||||
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
|
||||
]
|
||||
sampling_params = {"temperature": 0, "max_new_tokens": 512}
|
||||
output = engine.generate(warmup_prompt, sampling_params)
|
||||
test_prompt = [
|
||||
"<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive me a fully functional FastAPI server. Show the python code.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
]
|
||||
output = engine.generate(test_prompt, sampling_params)
|
||||
output = output[0]
|
||||
|
||||
if "spec_verify_ct" in output["meta_info"]:
|
||||
acc_length = (
|
||||
output["meta_info"]["completion_tokens"]
|
||||
/ output["meta_info"]["spec_verify_ct"]
|
||||
)
|
||||
else:
|
||||
acc_length = 1.0
|
||||
|
||||
speed = (
|
||||
output["meta_info"]["completion_tokens"]
|
||||
/ output["meta_info"]["e2e_latency"]
|
||||
)
|
||||
print(f"{acc_length=:.4f}, {speed=}")
|
||||
|
||||
self.assertGreater(acc_length, 2.5)
|
||||
|
||||
|
||||
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
|
||||
|
||||
@@ -307,6 +307,72 @@ class TestRadixCache(unittest.TestCase):
|
||||
result.device_indices, torch.tensor([10, 20], dtype=torch.int64)
|
||||
)
|
||||
|
||||
def test_insert_and_match_eagle(self):
|
||||
"""Test insert and match operations for EAGLE."""
|
||||
cache = RadixCache(
|
||||
req_to_token_pool=None,
|
||||
token_to_kv_pool_allocator=None,
|
||||
page_size=1,
|
||||
disable=False,
|
||||
is_eagle=True,
|
||||
)
|
||||
|
||||
key = RadixKey([1, 2, 3, 4])
|
||||
value = torch.tensor([10, 20, 30, 40], dtype=torch.int64)
|
||||
prefix_len = cache.insert(key, value)
|
||||
|
||||
self.assertEqual(prefix_len, 0) # No existing prefix
|
||||
self.assertEqual(
|
||||
cache.total_size(), 3
|
||||
) # The last token is ignored in bigram key
|
||||
self.assertEqual(cache.evictable_size(), 3)
|
||||
|
||||
# Test match_prefix
|
||||
result = cache.match_prefix(RadixKey([1, 2, 3, 4]))
|
||||
self.assertEqual(len(result.device_indices), 3)
|
||||
torch.testing.assert_close(
|
||||
result.device_indices, torch.tensor([10, 20, 30], dtype=torch.int64)
|
||||
)
|
||||
|
||||
# Test partial match
|
||||
result = cache.match_prefix(RadixKey([1, 2]))
|
||||
self.assertEqual(len(result.device_indices), 1)
|
||||
torch.testing.assert_close(
|
||||
result.device_indices, torch.tensor([10], dtype=torch.int64)
|
||||
)
|
||||
|
||||
def test_insert_and_match_eagle_page_size(self):
|
||||
"""Test insert and match operations for EAGLE and page_size > 1."""
|
||||
cache = RadixCache(
|
||||
req_to_token_pool=None,
|
||||
token_to_kv_pool_allocator=None,
|
||||
page_size=2,
|
||||
disable=False,
|
||||
is_eagle=True,
|
||||
)
|
||||
|
||||
key = RadixKey([1, 2, 3])
|
||||
value = torch.tensor([10, 20, 30], dtype=torch.int64)
|
||||
prefix_len = cache.insert(key, value)
|
||||
|
||||
self.assertEqual(prefix_len, 0) # No existing prefix
|
||||
self.assertEqual(cache.total_size(), 2) # only one page is inserted
|
||||
self.assertEqual(cache.evictable_size(), 2)
|
||||
|
||||
# Test match_prefix
|
||||
result = cache.match_prefix(RadixKey([1, 2, 3, 4]))
|
||||
self.assertEqual(len(result.device_indices), 2)
|
||||
torch.testing.assert_close(
|
||||
result.device_indices, torch.tensor([10, 20], dtype=torch.int64)
|
||||
)
|
||||
|
||||
# Test unmatched
|
||||
result = cache.match_prefix(RadixKey([1, 2]))
|
||||
self.assertEqual(len(result.device_indices), 0)
|
||||
torch.testing.assert_close(
|
||||
result.device_indices, torch.tensor([], dtype=torch.int64)
|
||||
)
|
||||
|
||||
def test_insert_with_none_value(self):
|
||||
"""Test insert with None value (should use token_ids as list)."""
|
||||
cache = RadixCache(
|
||||
|
||||
Reference in New Issue
Block a user