Fix eagle radix cache (#10846)
This commit is contained in:
@@ -547,6 +547,8 @@ class Req:
|
||||
self.host_hit_length = 0
|
||||
# The node to lock until for swa radix tree lock ref
|
||||
self.swa_uuid_for_lock: Optional[int] = None
|
||||
# The prefix length of the last prefix matching
|
||||
self.last_matched_prefix_len: int = 0
|
||||
|
||||
# Whether or not if it is chunked. It increments whenever
|
||||
# it is chunked, and decrement whenever chunked request is
|
||||
@@ -701,6 +703,7 @@ class Req:
|
||||
token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
|
||||
),
|
||||
)
|
||||
self.last_matched_prefix_len = len(self.prefix_indices)
|
||||
self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
|
||||
|
||||
def adjust_max_prefix_ids(self):
|
||||
|
||||
@@ -756,6 +756,7 @@ class Scheduler(
|
||||
disable=server_args.disable_radix_cache,
|
||||
enable_kv_cache_events=self.enable_kv_cache_events,
|
||||
eviction_policy=server_args.radix_eviction_policy,
|
||||
is_eagle=self.spec_algorithm.is_eagle(),
|
||||
)
|
||||
|
||||
if (
|
||||
|
||||
@@ -23,7 +23,7 @@ import heapq
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -159,6 +159,16 @@ def get_child_key(key: RadixKey, page_size: int = 1):
|
||||
return (key.extra_key, plain_key)
|
||||
|
||||
|
||||
def _convert_to_bigram_key(tokens: List[int]) -> List[Tuple[int, int]]:
|
||||
# EAGLE uses bigram keys in the radix tree since draft sequence is the one-token-shifted version of target
|
||||
# [1, 2, 3, 4] -> [(1,2), (2,3), (3,4)]
|
||||
if len(tokens) < 2:
|
||||
return []
|
||||
if isinstance(tokens[0], tuple):
|
||||
return tokens
|
||||
return [(tokens[i], tokens[i + 1]) for i in range(len(tokens) - 1)]
|
||||
|
||||
|
||||
class RadixCache(BasePrefixCache):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -168,6 +178,7 @@ class RadixCache(BasePrefixCache):
|
||||
disable: bool = False,
|
||||
enable_kv_cache_events: bool = False,
|
||||
eviction_policy: str = "lru",
|
||||
is_eagle: bool = False,
|
||||
):
|
||||
self.req_to_token_pool = req_to_token_pool
|
||||
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||
@@ -175,6 +186,7 @@ class RadixCache(BasePrefixCache):
|
||||
self.disable = disable
|
||||
self.enable_kv_cache_events = enable_kv_cache_events
|
||||
self.kv_event_queue = []
|
||||
self.is_eagle = is_eagle
|
||||
|
||||
if self.token_to_kv_pool_allocator:
|
||||
self.device = self.token_to_kv_pool_allocator.device
|
||||
@@ -188,6 +200,11 @@ class RadixCache(BasePrefixCache):
|
||||
self.key_match_fn = partial(_key_match_paged, page_size=page_size)
|
||||
self.get_child_key_fn = partial(get_child_key, page_size=page_size)
|
||||
|
||||
if is_eagle:
|
||||
self.key_convert_fn = _convert_to_bigram_key
|
||||
else:
|
||||
self.key_convert_fn = lambda key: key
|
||||
|
||||
if eviction_policy.lower() == "lru":
|
||||
self.eviction_strategy: EvictionStrategy = LRUStrategy()
|
||||
elif eviction_policy.lower() == "lfu":
|
||||
@@ -248,6 +265,8 @@ class RadixCache(BasePrefixCache):
|
||||
to expose a precise boundary; this structural refinement improves
|
||||
subsequent match efficiency and does not duplicate data.
|
||||
"""
|
||||
key.token_ids = self.key_convert_fn(key.token_ids)
|
||||
|
||||
if self.disable or len(key) == 0:
|
||||
return MatchResult(
|
||||
device_indices=torch.empty(
|
||||
@@ -278,8 +297,15 @@ class RadixCache(BasePrefixCache):
|
||||
if self.disable:
|
||||
return 0
|
||||
|
||||
key.token_ids = self.key_convert_fn(key.token_ids)
|
||||
|
||||
if value is None:
|
||||
value = torch.tensor(key.token_ids, dtype=torch.int64)
|
||||
|
||||
if self.is_eagle:
|
||||
# Make sure the value len equal to the EAGLE bigram key len
|
||||
value = value[: len(key)]
|
||||
|
||||
return self._insert_helper(self.root_node, key, value)
|
||||
|
||||
def cache_finished_req(self, req: Req):
|
||||
@@ -293,28 +319,39 @@ class RadixCache(BasePrefixCache):
|
||||
return
|
||||
|
||||
token_ids = (req.origin_input_ids + req.output_ids)[:-1]
|
||||
all_token_len = len(token_ids)
|
||||
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
req.req_pool_idx, :all_token_len
|
||||
]
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
||||
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
||||
dtype=torch.int64, copy=True
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
else:
|
||||
page_aligned_len = len(kv_indices)
|
||||
page_aligned_len = actual_kv_len
|
||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||
if self.is_eagle:
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[page_aligned_len:])
|
||||
|
||||
page_aligned_token_len = (
|
||||
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
||||
)
|
||||
|
||||
old_prefix_len = len(req.prefix_indices)
|
||||
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
||||
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
|
||||
old_prefix_len -= 1
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(
|
||||
RadixKey(token_ids[:page_aligned_len], req.extra_key),
|
||||
RadixKey(token_ids[:page_aligned_token_len], req.extra_key),
|
||||
page_aligned_kv_indices,
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
||||
|
||||
# Remove req slot release the cache lock
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
@@ -326,19 +363,32 @@ class RadixCache(BasePrefixCache):
|
||||
return
|
||||
|
||||
token_ids = req.fill_ids
|
||||
all_token_len = len(token_ids)
|
||||
# The actual kv len for EAGLE is len(token_ids), since EAGLE uses bigram key
|
||||
actual_kv_len = all_token_len - 1 if self.is_eagle else all_token_len
|
||||
kv_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : len(token_ids)
|
||||
req.req_pool_idx, :all_token_len
|
||||
]
|
||||
|
||||
if self.page_size != 1:
|
||||
page_aligned_len = len(kv_indices) // self.page_size * self.page_size
|
||||
page_aligned_len = actual_kv_len // self.page_size * self.page_size
|
||||
page_aligned_kv_indices = kv_indices[:page_aligned_len].to(
|
||||
dtype=torch.int64, copy=True
|
||||
)
|
||||
else:
|
||||
page_aligned_len = len(kv_indices)
|
||||
page_aligned_len = actual_kv_len
|
||||
page_aligned_kv_indices = kv_indices.to(dtype=torch.int64, copy=True)
|
||||
page_aligned_token_ids = token_ids[:page_aligned_len]
|
||||
|
||||
# For EAGLE, the page_aligned_len is for the bigram key, the normal key len should +1
|
||||
page_aligned_token_len = (
|
||||
page_aligned_len + 1 if self.is_eagle else page_aligned_len
|
||||
)
|
||||
page_aligned_token_ids = token_ids[:page_aligned_token_len]
|
||||
|
||||
old_prefix_len = len(req.prefix_indices)
|
||||
if self.is_eagle and old_prefix_len > req.last_matched_prefix_len:
|
||||
# prefix_indices attached partial part (for page_size > 1) and one unmatched token (for EAGLE)
|
||||
old_prefix_len -= 1
|
||||
|
||||
# Radix Cache takes one ref in memory pool
|
||||
new_prefix_len = self.insert(
|
||||
@@ -346,29 +396,40 @@ class RadixCache(BasePrefixCache):
|
||||
page_aligned_kv_indices,
|
||||
chunked=chunked,
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(
|
||||
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||
)
|
||||
self.token_to_kv_pool_allocator.free(kv_indices[old_prefix_len:new_prefix_len])
|
||||
|
||||
# The prefix indices could be updated, reuse it
|
||||
new_indices, new_last_node, _, _ = self.match_prefix(
|
||||
RadixKey(token_ids=page_aligned_token_ids, extra_key=req.extra_key)
|
||||
)
|
||||
self.req_to_token_pool.write(
|
||||
(req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
|
||||
new_indices[len(req.prefix_indices) :],
|
||||
(req.req_pool_idx, slice(old_prefix_len, len(new_indices))),
|
||||
new_indices[old_prefix_len:],
|
||||
)
|
||||
|
||||
# The last_matched_prefix_len is not always equal to len(req.prefix_indices)
|
||||
# since for page_size > 1, the partial part is added to req.prefix_indices, but that part of kv indices is not added to the tree.
|
||||
# It should be freed in the next cache_unfinished_req and final cache_finished_req to avoid memory leak.
|
||||
# So we introduce this `last_matched_prefix_len` field to make sure the partial part can be freed correctly.
|
||||
req.last_matched_prefix_len = len(new_indices)
|
||||
|
||||
self.dec_lock_ref(req.last_node)
|
||||
self.inc_lock_ref(new_last_node)
|
||||
|
||||
# `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
|
||||
if self.page_size != 1:
|
||||
# Handle partial page, the partial part should be freed in the next cache_unfinished_req and final cache_finished_req.
|
||||
req.prefix_indices = torch.cat(
|
||||
[new_indices, kv_indices[len(new_indices) :]]
|
||||
)
|
||||
else:
|
||||
req.prefix_indices = new_indices
|
||||
if self.is_eagle:
|
||||
# Attach the kv index of the last token for EAGLE, it can be used in chunked prefill
|
||||
req.prefix_indices = torch.cat(
|
||||
[new_indices, kv_indices[actual_kv_len:]]
|
||||
)
|
||||
else:
|
||||
req.prefix_indices = new_indices
|
||||
req.last_node = new_last_node
|
||||
|
||||
def pretty_print(self):
|
||||
|
||||
@@ -77,7 +77,8 @@ DEFAULT_MODEL_NAME_FOR_TEST_W8A8_WITH_MOE = "nytopop/Qwen3-30B-A3B.w8a8"
|
||||
# EAGLE
|
||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST = "meta-llama/Llama-2-7b-chat-hf"
|
||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST = "lmsys/sglang-EAGLE-llama2-chat-7B"
|
||||
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B"
|
||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3 = "meta-llama/Llama-3.1-8B-Instruct"
|
||||
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3 = "lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B"
|
||||
DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
|
||||
"meta-llama/Llama-3.1-8B-Instruct"
|
||||
)
|
||||
|
||||
@@ -9,6 +9,8 @@ from sglang.srt.utils import kill_process_tree
|
||||
from sglang.test.test_utils import (
|
||||
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3,
|
||||
DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
|
||||
DEFAULT_MODEL_NAME_FOR_TEST_MLA,
|
||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
DEFAULT_URL_FOR_TEST,
|
||||
@@ -35,6 +37,11 @@ class TestEAGLEEngine(CustomTestCase):
|
||||
}
|
||||
NUM_CONFIGS = 2
|
||||
|
||||
THRESHOLDS = {
|
||||
"batch_avg_accept_len": 1.9,
|
||||
"accept_len": 3.6,
|
||||
}
|
||||
|
||||
def setUp(self):
|
||||
self.prompt = "Today is a sunny day and I like"
|
||||
self.sampling_params = {"temperature": 0, "max_new_tokens": 8}
|
||||
@@ -63,6 +70,7 @@ class TestEAGLEEngine(CustomTestCase):
|
||||
self._test_eos_token(engine)
|
||||
self._test_acc_length(engine)
|
||||
finally:
|
||||
engine.flush_cache() # check engine alive
|
||||
engine.shutdown()
|
||||
print("=" * 100)
|
||||
|
||||
@@ -92,7 +100,9 @@ class TestEAGLEEngine(CustomTestCase):
|
||||
"avg_spec_accept_length"
|
||||
]
|
||||
print(f"{avg_spec_accept_length=}")
|
||||
self.assertGreater(avg_spec_accept_length, 1.9)
|
||||
self.assertGreater(
|
||||
avg_spec_accept_length, self.THRESHOLDS["batch_avg_accept_len"]
|
||||
)
|
||||
|
||||
def _test_eos_token(self, engine):
|
||||
prompt = "[INST] <<SYS>>\nYou are a helpful assistant.\n<</SYS>>\nToday is a sunny day and I like [/INST]"
|
||||
@@ -131,10 +141,7 @@ class TestEAGLEEngine(CustomTestCase):
|
||||
)
|
||||
print(f"{acc_length=:.4f}, {speed=}")
|
||||
|
||||
if engine.server_args.model_path == DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST:
|
||||
self.assertGreater(acc_length, 3.6)
|
||||
else:
|
||||
self.assertGreater(acc_length, 2.5)
|
||||
self.assertGreater(acc_length, self.THRESHOLDS["accept_len"])
|
||||
|
||||
|
||||
class TestEAGLEEngineTokenMap(TestEAGLEEngine):
|
||||
@@ -151,12 +158,16 @@ class TestEAGLEEngineTokenMap(TestEAGLEEngine):
|
||||
"dtype": "float16",
|
||||
}
|
||||
NUM_CONFIGS = 1
|
||||
THRESHOLDS = {
|
||||
"batch_avg_accept_len": 1.9,
|
||||
"accept_len": 2.5,
|
||||
}
|
||||
|
||||
|
||||
class TestEAGLE3Engine(TestEAGLEEngine):
|
||||
BASE_CONFIG = {
|
||||
"model_path": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"speculative_draft_model_path": "jamesliu1/sglang-EAGLE3-Llama-3.1-Instruct-8B",
|
||||
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3,
|
||||
"speculative_draft_model_path": DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
|
||||
"speculative_algorithm": "EAGLE3",
|
||||
"speculative_num_steps": 5,
|
||||
"speculative_eagle_topk": 16,
|
||||
@@ -166,6 +177,72 @@ class TestEAGLE3Engine(TestEAGLEEngine):
|
||||
"dtype": "float16",
|
||||
}
|
||||
NUM_CONFIGS = 1
|
||||
THRESHOLDS = {
|
||||
"batch_avg_accept_len": 1.75,
|
||||
"accept_len": 3.1,
|
||||
}
|
||||
|
||||
|
||||
class TestEAGLERadixCache(CustomTestCase):
|
||||
BASE_CONFIG = {
|
||||
"model_path": DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST_EAGLE3,
|
||||
"speculative_draft_model_path": DEFAULT_MODEL_NAME_FOR_TEST_EAGLE3,
|
||||
"speculative_algorithm": "EAGLE3",
|
||||
"speculative_num_steps": 2,
|
||||
"speculative_eagle_topk": 1,
|
||||
"speculative_num_draft_tokens": 3,
|
||||
"mem_fraction_static": 0.7,
|
||||
"cuda_graph_max_bs": 5,
|
||||
"dtype": "float16",
|
||||
}
|
||||
|
||||
def test_correctness(self):
|
||||
configs = [
|
||||
# Basic config
|
||||
self.BASE_CONFIG,
|
||||
# Chunked prefill
|
||||
{**self.BASE_CONFIG, "chunked_prefill_size": 64},
|
||||
# Chunked prefill & Page Size > 1
|
||||
{**self.BASE_CONFIG, "chunked_prefill_size": 64, "page_size": 4},
|
||||
]
|
||||
|
||||
for i, config in enumerate(configs):
|
||||
with self.subTest(i=i):
|
||||
print(f"{config=}")
|
||||
engine = sgl.Engine(**config, log_level="info", decode_log_interval=10)
|
||||
try:
|
||||
self._test_acc_length(engine)
|
||||
finally:
|
||||
engine.shutdown()
|
||||
print("=" * 100)
|
||||
|
||||
def _test_acc_length(self, engine):
|
||||
warmup_prompt = [
|
||||
"Human: Give me a fully functional FastAPI server. Show the python code.\n\nAssistant:",
|
||||
]
|
||||
sampling_params = {"temperature": 0, "max_new_tokens": 512}
|
||||
output = engine.generate(warmup_prompt, sampling_params)
|
||||
test_prompt = [
|
||||
"<|start_header_id|>system<|end_header_id|>\n\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.<|eot_id|><|start_header_id|>user<|end_header_id|>\n\nGive me a fully functional FastAPI server. Show the python code.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n\n"
|
||||
]
|
||||
output = engine.generate(test_prompt, sampling_params)
|
||||
output = output[0]
|
||||
|
||||
if "spec_verify_ct" in output["meta_info"]:
|
||||
acc_length = (
|
||||
output["meta_info"]["completion_tokens"]
|
||||
/ output["meta_info"]["spec_verify_ct"]
|
||||
)
|
||||
else:
|
||||
acc_length = 1.0
|
||||
|
||||
speed = (
|
||||
output["meta_info"]["completion_tokens"]
|
||||
/ output["meta_info"]["e2e_latency"]
|
||||
)
|
||||
print(f"{acc_length=:.4f}, {speed=}")
|
||||
|
||||
self.assertGreater(acc_length, 2.5)
|
||||
|
||||
|
||||
@unittest.skipIf(is_in_ci(), "To reduce the CI execution time.")
|
||||
|
||||
@@ -307,6 +307,72 @@ class TestRadixCache(unittest.TestCase):
|
||||
result.device_indices, torch.tensor([10, 20], dtype=torch.int64)
|
||||
)
|
||||
|
||||
def test_insert_and_match_eagle(self):
|
||||
"""Test insert and match operations for EAGLE."""
|
||||
cache = RadixCache(
|
||||
req_to_token_pool=None,
|
||||
token_to_kv_pool_allocator=None,
|
||||
page_size=1,
|
||||
disable=False,
|
||||
is_eagle=True,
|
||||
)
|
||||
|
||||
key = RadixKey([1, 2, 3, 4])
|
||||
value = torch.tensor([10, 20, 30, 40], dtype=torch.int64)
|
||||
prefix_len = cache.insert(key, value)
|
||||
|
||||
self.assertEqual(prefix_len, 0) # No existing prefix
|
||||
self.assertEqual(
|
||||
cache.total_size(), 3
|
||||
) # The last token is ignored in bigram key
|
||||
self.assertEqual(cache.evictable_size(), 3)
|
||||
|
||||
# Test match_prefix
|
||||
result = cache.match_prefix(RadixKey([1, 2, 3, 4]))
|
||||
self.assertEqual(len(result.device_indices), 3)
|
||||
torch.testing.assert_close(
|
||||
result.device_indices, torch.tensor([10, 20, 30], dtype=torch.int64)
|
||||
)
|
||||
|
||||
# Test partial match
|
||||
result = cache.match_prefix(RadixKey([1, 2]))
|
||||
self.assertEqual(len(result.device_indices), 1)
|
||||
torch.testing.assert_close(
|
||||
result.device_indices, torch.tensor([10], dtype=torch.int64)
|
||||
)
|
||||
|
||||
def test_insert_and_match_eagle_page_size(self):
|
||||
"""Test insert and match operations for EAGLE and page_size > 1."""
|
||||
cache = RadixCache(
|
||||
req_to_token_pool=None,
|
||||
token_to_kv_pool_allocator=None,
|
||||
page_size=2,
|
||||
disable=False,
|
||||
is_eagle=True,
|
||||
)
|
||||
|
||||
key = RadixKey([1, 2, 3])
|
||||
value = torch.tensor([10, 20, 30], dtype=torch.int64)
|
||||
prefix_len = cache.insert(key, value)
|
||||
|
||||
self.assertEqual(prefix_len, 0) # No existing prefix
|
||||
self.assertEqual(cache.total_size(), 2) # only one page is inserted
|
||||
self.assertEqual(cache.evictable_size(), 2)
|
||||
|
||||
# Test match_prefix
|
||||
result = cache.match_prefix(RadixKey([1, 2, 3, 4]))
|
||||
self.assertEqual(len(result.device_indices), 2)
|
||||
torch.testing.assert_close(
|
||||
result.device_indices, torch.tensor([10, 20], dtype=torch.int64)
|
||||
)
|
||||
|
||||
# Test unmatched
|
||||
result = cache.match_prefix(RadixKey([1, 2]))
|
||||
self.assertEqual(len(result.device_indices), 0)
|
||||
torch.testing.assert_close(
|
||||
result.device_indices, torch.tensor([], dtype=torch.int64)
|
||||
)
|
||||
|
||||
def test_insert_with_none_value(self):
|
||||
"""Test insert with None value (should use token_ids as list)."""
|
||||
cache = RadixCache(
|
||||
|
||||
Reference in New Issue
Block a user