diff --git a/python/sglang/srt/managers/schedule_policy.py b/python/sglang/srt/managers/schedule_policy.py index 60633552b..fe309e3d8 100644 --- a/python/sglang/srt/managers/schedule_policy.py +++ b/python/sglang/srt/managers/schedule_policy.py @@ -583,6 +583,7 @@ class PrefillAdder: req.prefix_indices = torch.cat([req.prefix_indices, new_indices]) req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices) prefix_len = len(req.prefix_indices) + req.last_matched_prefix_len = prefix_len input_tokens = self.ceil_paged_tokens(req.extend_input_len) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 38a241640..e1d558c1b 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -762,6 +762,7 @@ class Scheduler( hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy, model_name=server_args.served_model_name, storage_backend_extra_config=server_args.hicache_storage_backend_extra_config, + is_eagle=self.spec_algorithm.is_eagle(), ) self.tp_worker.register_hicache_layer_transfer_counter( self.tree_cache.cache_controller.layer_done_counter diff --git a/python/sglang/srt/mem_cache/hiradix_cache.py b/python/sglang/srt/mem_cache/hiradix_cache.py index c3d6342d9..1a8b6accc 100644 --- a/python/sglang/srt/mem_cache/hiradix_cache.py +++ b/python/sglang/srt/mem_cache/hiradix_cache.py @@ -44,6 +44,7 @@ class HiRadixCache(RadixCache): hicache_storage_prefetch_policy: Optional[str] = "best_effort", model_name: Optional[str] = None, storage_backend_extra_config: Optional[str] = None, + is_eagle: bool = False, ): if hicache_io_backend == "direct": @@ -135,6 +136,7 @@ class HiRadixCache(RadixCache): page_size, disable=False, eviction_policy=eviction_policy, + is_eagle=is_eagle, ) def _parse_storage_backend_extra_config( @@ -658,6 +660,7 @@ class HiRadixCache(RadixCache): def match_prefix(self, key: RadixKey, **kwargs): empty_value = torch.empty((0,), dtype=torch.int64, device=self.device) + key.token_ids = self.key_convert_fn(key.token_ids) if self.disable or len(key) == 0: return MatchResult( device_indices=empty_value, @@ -820,9 +823,15 @@ class HiRadixCache(RadixCache): return new_node def insert(self, key: RadixKey, value=None, chunked=False): + key.token_ids = self.key_convert_fn(key.token_ids) + if len(key) == 0: return 0 + if self.is_eagle and value is not None: + # Make sure the value len equal to the EAGLE bigram key len + value = value[: len(key)] + node = self.root_node child_key = self.get_child_key_fn(key) total_prefix_length = 0 diff --git a/test/srt/hicache/test_hicache_eagle.py b/test/srt/hicache/test_hicache_eagle.py new file mode 100644 index 000000000..f6265b9c1 --- /dev/null +++ b/test/srt/hicache/test_hicache_eagle.py @@ -0,0 +1,78 @@ +import unittest +from types import SimpleNamespace + +import requests + +from sglang.bench_serving import get_tokenizer +from sglang.srt.utils import kill_process_tree +from sglang.test.run_eval import run_eval +from sglang.test.test_utils import ( + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3, + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + CustomTestCase, + popen_launch_server, +) + + +class TestHiCacheEagle(CustomTestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3 + cls.base_url = DEFAULT_URL_FOR_TEST + cls.tokenizer = get_tokenizer(cls.model) + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=[ + "--enable-hierarchical-cache", + "--hicache-ratio", + 1.2, + "--mem-fraction-static", + 0.7, + "--speculative-algorithm", + "EAGLE3", + "--speculative-draft-model-path", + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, + "--speculative-num-steps", + 2, + "--speculative-eagle-topk", + 1, + "--speculative-num-draft-tokens", + 3, + "--dtype", + "float16", + "--chunked-prefill-size", + 1024, + ], + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_mmlu(self): + args = SimpleNamespace( + base_url=self.base_url, + model=self.model, + eval_name="mmlu", + num_examples=64, + num_threads=32, + ) + + metrics = run_eval(args) + self.assertGreaterEqual(metrics["score"], 0.72) + + server_info = requests.get(self.base_url + "/get_server_info") + print(f"{server_info=}") + avg_spec_accept_length = server_info.json()["internal_states"][0][ + "avg_spec_accept_length" + ] + print(f"{avg_spec_accept_length=}") + self.assertGreater(avg_spec_accept_length, 2.26) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 5ffe6b62d..08900cae4 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -17,6 +17,7 @@ suites = { TestFile("hicache/test_hicache.py", 116), TestFile("hicache/test_hicache_mla.py", 127), TestFile("hicache/test_hicache_storage.py", 127), + TestFile("hicache/test_hicache_eagle.py", 150), TestFile("lora/test_lora.py", 200), TestFile("lora/test_lora_eviction.py", 200), TestFile("lora/test_lora_backend.py", 99),