From 91847e382acbfdec0a40a58918e47bbd2191ef6a Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Tue, 30 Sep 2025 22:59:20 +0800 Subject: [PATCH] Fix eagle radix cache (#10846) --- python/sglang/srt/managers/schedule_batch.py | 3 + python/sglang/srt/managers/scheduler.py | 1 + python/sglang/srt/mem_cache/radix_cache.py | 97 ++++++++++++++++---- python/sglang/test/test_utils.py | 3 +- test/srt/test_eagle_infer_a.py | 91 ++++++++++++++++-- test/srt/test_radix_cache_unit.py | 66 +++++++++++++ 6 files changed, 235 insertions(+), 26 deletions(-) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 97f5d00ff..ff9edc58b 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -547,6 +547,8 @@ class Req: self.host_hit_length = 0 # The node to lock until for swa radix tree lock ref self.swa_uuid_for_lock: Optional[int] = None + # The prefix length of the last prefix matching + self.last_matched_prefix_len: int = 0 # Whether or not if it is chunked. It increments whenever # it is chunked, and decrement whenever chunked request is @@ -701,6 +703,7 @@ class Req: token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key ), ) + self.last_matched_prefix_len = len(self.prefix_indices) self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices) def adjust_max_prefix_ids(self): diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 893a0b0a1..d62c7f01c 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -756,6 +756,7 @@ class Scheduler( disable=server_args.disable_radix_cache, enable_kv_cache_events=self.enable_kv_cache_events, eviction_policy=server_args.radix_eviction_policy, + is_eagle=self.spec_algorithm.is_eagle(), ) if ( diff --git a/python/sglang/srt/mem_cache/radix_cache.py b/python/sglang/srt/mem_cache/radix_cache.py index abb9445f8..2f818770a 100644 --- a/python/sglang/srt/mem_cache/radix_cache.py +++ b/python/sglang/srt/mem_cache/radix_cache.py @@ -23,7 +23,7 @@ import heapq import time from collections import defaultdict from functools import partial -from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union +from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union import torch @@ -159,6 +159,16 @@ def get_child_key(key: RadixKey, page_size: int = 1): return (key.extra_key, plain_key) +def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]: + # EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target + # [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)] + if len(tokens) < 2: + return [] + if isinstance(tokens[0], tuple): + return tokens + return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)] + + class RadixCache(BasePrefixCache): def __init__( self, @@ -168,6 +178,7 @@ class RadixCache(BasePrefixCache): disable: bool = False, enable_kv_cache_events: bool = False, eviction_policy: str = "lru", + is_eagle: bool = False, ): self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator @@ -175,6 +186,7 @@ class RadixCache(BasePrefixCache): self.disable = disable self.enable_kv_cache_events = enable_kv_cache_events self.kv_event_queue = [] + self.is_eagle = is_eagle if self.token_to_kv_pool_allocator: self.device = self.token_to_kv_pool_allocator.device @@ -188,6 +200,11 @@ class RadixCache(BasePrefixCache): self.key_match_fn = partial(_key_match_paged, page_size=page_size) self.get_child_key_fn = partial(get_child_key, page_size=page_size) + if is_eagle: + self.key_convert_fn = _convert_to_bigram_key + else: + self.key_convert_fn = lambda key: key + if eviction_policy.lower() == "lru": self.eviction_strategy: EvictionStrategy = LRUStrategy() elif eviction_policy.lower() == "lfu": @@ -248,6 +265,8 @@ class RadixCache(BasePrefixCache): to expose a precise boundary; this structural refinement improves subsequent match efficiency and does not duplicate data. """ + key.token_ids = self.key_convert_fn(key.token_ids) + if self.disable or len(key) == 0: return MatchResult( device_indices=torch.empty( @@ -278,8 +297,15 @@ class RadixCache(BasePrefixCache): if self.disable: return 0 + key.token_ids = self.key_convert_fn(key.token_ids) + if value is None: value = torch.tensor(key.token_ids, dtype=torch.int64) + + if self.is_eagle: + # Make sure the value len equal to the EAGLE bigram key len + value = value[: len(key)] + return self._insert_helper(self.root_node, key, value) def cache_finished_req(self, req: Req): @@ -293,28 +319,39 @@ class RadixCache(BasePrefixCache): return token_ids = (req.origin_input_ids + req.output_ids)[:-1] + all_token_len = len(token_ids) + actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(token_ids) + req.req_pool_idx, :all_token_len ] if self.page_size != 1: - page_aligned_len = len(kv_indices) // self.page_size * self.page_size + page_aligned_len = actual_kv_len // self.page_size * self.page_size page_aligned_kv_indices = kv_indices[:page_aligned_len].to( dtype=torch.int64, copy=True ) self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) else: - page_aligned_len = len(kv_indices) + page_aligned_len = actual_kv_len page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) + if self.is_eagle: + self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:]) + + page_aligned_token_len = ( + page_aligned_len + 1 if self.is_eagle else page_aligned_len + ) + + old_prefix_len = len(req.prefix_indices) + if self.is_eagle and old_prefix_len > req.last_matched_prefix_len: + # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE) + old_prefix_len -= 1 # Radix Cache takes one ref in memory pool new_prefix_len = self.insert( - RadixKey(token_ids[:page_aligned_len], req.extra_key), + RadixKey(token_ids[:page_aligned_token_len], req.extra_key), page_aligned_kv_indices, ) - self.token_to_kv_pool_allocator.free( - kv_indices[len(req.prefix_indices) : new_prefix_len] - ) + self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len]) # Remove req slot release the cache lock self.req_to_token_pool.free(req.req_pool_idx) @@ -326,19 +363,32 @@ class RadixCache(BasePrefixCache): return token_ids = req.fill_ids + all_token_len = len(token_ids) + # The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key + actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len kv_indices = self.req_to_token_pool.req_to_token[ - req.req_pool_idx, : len(token_ids) + req.req_pool_idx, :all_token_len ] if self.page_size != 1: - page_aligned_len = len(kv_indices) // self.page_size * self.page_size + page_aligned_len = actual_kv_len // self.page_size * self.page_size page_aligned_kv_indices = kv_indices[:page_aligned_len].to( dtype=torch.int64, copy=True ) else: - page_aligned_len = len(kv_indices) + page_aligned_len = actual_kv_len page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True) - page_aligned_token_ids = token_ids[:page_aligned_len] + + # For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1 + page_aligned_token_len = ( + page_aligned_len + 1 if self.is_eagle else page_aligned_len + ) + page_aligned_token_ids = token_ids[:page_aligned_token_len] + + old_prefix_len = len(req.prefix_indices) + if self.is_eagle and old_prefix_len > req.last_matched_prefix_len: + # prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE) + old_prefix_len -= 1 # Radix Cache takes one ref in memory pool new_prefix_len = self.insert( @@ -346,29 +396,40 @@ class RadixCache(BasePrefixCache): page_aligned_kv_indices, chunked=chunked, ) - self.token_to_kv_pool_allocator.free( - kv_indices[len(req.prefix_indices) : new_prefix_len] - ) + self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len]) # The prefix indices could be updated, reuse it new_indices, new_last_node, _, _ = self.match_prefix( RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key) ) self.req_to_token_pool.write( - (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))), - new_indices[len(req.prefix_indices) :], + (req.req_pool_idx, slice(old_prefix_len, len(new_indices))), + new_indices[old_prefix_len:], ) + # The last_matched_prefix_len is not always equal to len(req.prefix_indices) + # since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree. + # It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak. + # So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly. + req.last_matched_prefix_len = len(new_indices) + self.dec_lock_ref(req.last_node) self.inc_lock_ref(new_last_node) # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later if self.page_size != 1: + # Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req. req.prefix_indices = torch.cat( [new_indices, kv_indices[len(new_indices) :]] ) else: - req.prefix_indices = new_indices + if self.is_eagle: + # Attach the kv index of the last token for EAGLE, it can be used in chunked prefill + req.prefix_indices = torch.cat( + [new_indices, kv_indices[actual_kv_len:]] + ) + else: + req.prefix_indices = new_indices req.last_node = new_last_node def pretty_print(self): diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 2e9a16896..360c852cb 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -77,7 +77,8 @@ DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE = "nytopop/Qwen3-30B-A3B.w8a8" # EAGLE DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf" DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B" -DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B" +DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct" +DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B" DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = ( "meta-llama/Llama-3.1-8B-Instruct" ) diff --git a/test/srt/test_eagle_infer_a.py b/test/srt/test_eagle_infer_a.py index c19f0c22f..f956059c0 100644 --- a/test/srt/test_eagle_infer_a.py +++ b/test/srt/test_eagle_infer_a.py @@ -9,6 +9,8 @@ from sglang.srt.utils import kill_process_tree from sglang.test.test_utils import ( DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST, DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST, + DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3, + DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, DEFAULT_MODEL_NAME_FOR_TEST_MLA, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, @@ -35,6 +37,11 @@ class TestEAGLEEngine(CustomTestCase): } NUM_CONFIGS = 2 + THRESHOLDS = { + "batch_avg_accept_len": 1.9, + "accept_len": 3.6, + } + def setUp(self): self.prompt = "Today is a sunny day and I like" self.sampling_params = {"temperature": 0, "max_new_tokens": 8} @@ -63,6 +70,7 @@ class TestEAGLEEngine(CustomTestCase): self._test_eos_token(engine) self._test_acc_length(engine) finally: + engine.flush_cache() # check engine alive engine.shutdown() print("=" * 100) @@ -92,7 +100,9 @@ class TestEAGLEEngine(CustomTestCase): "avg_spec_accept_length" ] print(f"{avg_spec_accept_length=}") - self.assertGreater(avg_spec_accept_length, 1.9) + self.assertGreater( + avg_spec_accept_length, self.THRESHOLDS["batch_avg_accept_len"] + ) def _test_eos_token(self, engine): prompt = "[INST] <>\nYou are a helpful assistant.\n<>\nToday is a sunny day and I like [/INST]" @@ -131,10 +141,7 @@ class TestEAGLEEngine(CustomTestCase): ) print(f"{acc_length=:.4f}, {speed=}") - if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST: - self.assertGreater(acc_length, 3.6) - else: - self.assertGreater(acc_length, 2.5) + self.assertGreater(acc_length, self.THRESHOLDS["accept_len"]) class TestEAGLEEngineTokenMap(TestEAGLEEngine): @@ -151,12 +158,16 @@ class TestEAGLEEngineTokenMap(TestEAGLEEngine): "dtype": "float16", } NUM_CONFIGS = 1 + THRESHOLDS = { + "batch_avg_accept_len": 1.9, + "accept_len": 2.5, + } class TestEAGLE3Engine(TestEAGLEEngine): BASE_CONFIG = { - "model_path": "meta-llama/Llama-3.1-8B-Instruct", - "speculative_draft_model_path": "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B", + "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3, + "speculative_draft_model_path": DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, "speculative_algorithm": "EAGLE3", "speculative_num_steps": 5, "speculative_eagle_topk": 16, @@ -166,6 +177,72 @@ class TestEAGLE3Engine(TestEAGLEEngine): "dtype": "float16", } NUM_CONFIGS = 1 + THRESHOLDS = { + "batch_avg_accept_len": 1.75, + "accept_len": 3.1, + } + + +class TestEAGLERadixCache(CustomTestCase): + BASE_CONFIG = { + "model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3, + "speculative_draft_model_path": DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3, + "speculative_algorithm": "EAGLE3", + "speculative_num_steps": 2, + "speculative_eagle_topk": 1, + "speculative_num_draft_tokens": 3, + "mem_fraction_static": 0.7, + "cuda_graph_max_bs": 5, + "dtype": "float16", + } + + def test_correctness(self): + configs = [ + # Basic config + self.BASE_CONFIG, + # Chunked prefill + {**self.BASE_CONFIG, "chunked_prefill_size": 64}, + # Chunked prefill & Page Size > 1 + {**self.BASE_CONFIG, "chunked_prefill_size": 64, "page_size": 4}, + ] + + for i, config in enumerate(configs): + with self.subTest(i=i): + print(f"{config=}") + engine = sgl.Engine(**config, log_level="info", decode_log_interval=10) + try: + self._test_acc_length(engine) + finally: + engine.shutdown() + print("=" * 100) + + def _test_acc_length(self, engine): + warmup_prompt = [ + "Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:", + ] + sampling_params = {"temperature": 0, "max_new_tokens": 512} + output = engine.generate(warmup_prompt, sampling_params) + test_prompt = [ + "<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive me a fully functional FastAPI server. Show the python code.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n" + ] + output = engine.generate(test_prompt, sampling_params) + output = output[0] + + if "spec_verify_ct" in output["meta_info"]: + acc_length = ( + output["meta_info"]["completion_tokens"] + / output["meta_info"]["spec_verify_ct"] + ) + else: + acc_length = 1.0 + + speed = ( + output["meta_info"]["completion_tokens"] + / output["meta_info"]["e2e_latency"] + ) + print(f"{acc_length=:.4f}, {speed=}") + + self.assertGreater(acc_length, 2.5) @unittest.skipIf(is_in_ci(), "To reduce the CI execution time.") diff --git a/test/srt/test_radix_cache_unit.py b/test/srt/test_radix_cache_unit.py index 8cb75fb0b..f8708eaf3 100644 --- a/test/srt/test_radix_cache_unit.py +++ b/test/srt/test_radix_cache_unit.py @@ -307,6 +307,72 @@ class TestRadixCache(unittest.TestCase): result.device_indices, torch.tensor([10, 20], dtype=torch.int64) ) + def test_insert_and_match_eagle(self): + """Test insert and match operations for EAGLE.""" + cache = RadixCache( + req_to_token_pool=None, + token_to_kv_pool_allocator=None, + page_size=1, + disable=False, + is_eagle=True, + ) + + key = RadixKey([1, 2, 3, 4]) + value = torch.tensor([10, 20, 30, 40], dtype=torch.int64) + prefix_len = cache.insert(key, value) + + self.assertEqual(prefix_len, 0) # No existing prefix + self.assertEqual( + cache.total_size(), 3 + ) # The last token is ignored in bigram key + self.assertEqual(cache.evictable_size(), 3) + + # Test match_prefix + result = cache.match_prefix(RadixKey([1, 2, 3, 4])) + self.assertEqual(len(result.device_indices), 3) + torch.testing.assert_close( + result.device_indices, torch.tensor([10, 20, 30], dtype=torch.int64) + ) + + # Test partial match + result = cache.match_prefix(RadixKey([1, 2])) + self.assertEqual(len(result.device_indices), 1) + torch.testing.assert_close( + result.device_indices, torch.tensor([10], dtype=torch.int64) + ) + + def test_insert_and_match_eagle_page_size(self): + """Test insert and match operations for EAGLE and page_size > 1.""" + cache = RadixCache( + req_to_token_pool=None, + token_to_kv_pool_allocator=None, + page_size=2, + disable=False, + is_eagle=True, + ) + + key = RadixKey([1, 2, 3]) + value = torch.tensor([10, 20, 30], dtype=torch.int64) + prefix_len = cache.insert(key, value) + + self.assertEqual(prefix_len, 0) # No existing prefix + self.assertEqual(cache.total_size(), 2) # only one page is inserted + self.assertEqual(cache.evictable_size(), 2) + + # Test match_prefix + result = cache.match_prefix(RadixKey([1, 2, 3, 4])) + self.assertEqual(len(result.device_indices), 2) + torch.testing.assert_close( + result.device_indices, torch.tensor([10, 20], dtype=torch.int64) + ) + + # Test unmatched + result = cache.match_prefix(RadixKey([1, 2])) + self.assertEqual(len(result.device_indices), 0) + torch.testing.assert_close( + result.device_indices, torch.tensor([], dtype=torch.int64) + ) + def test_insert_with_none_value(self): """Test insert with None value (should use token_ids as list).""" cache = RadixCache(