[Eagle] Refactor eagle speculative decoding (#3986)
Co-authored-by: Ke Bao <ISPObaoke@163.com>
This commit is contained in:
@@ -44,18 +44,16 @@ from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
||||
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
|
||||
|
||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
|
||||
# Put some global args for easy access
|
||||
@@ -523,7 +521,7 @@ class ScheduleBatch:
|
||||
# Request, memory pool, and cache
|
||||
reqs: List[Req]
|
||||
req_to_token_pool: ReqToTokenPool = None
|
||||
token_to_kv_pool: BaseTokenToKVPool = None
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
|
||||
tree_cache: BasePrefixCache = None
|
||||
|
||||
# Batch configs
|
||||
@@ -596,7 +594,7 @@ class ScheduleBatch:
|
||||
cls,
|
||||
reqs: List[Req],
|
||||
req_to_token_pool: ReqToTokenPool,
|
||||
token_to_kv_pool: ReqToTokenPool,
|
||||
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||
tree_cache: BasePrefixCache,
|
||||
model_config: ModelConfig,
|
||||
enable_overlap: bool,
|
||||
@@ -606,7 +604,7 @@ class ScheduleBatch:
|
||||
return cls(
|
||||
reqs=reqs,
|
||||
req_to_token_pool=req_to_token_pool,
|
||||
token_to_kv_pool=token_to_kv_pool,
|
||||
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
|
||||
tree_cache=tree_cache,
|
||||
model_config=model_config,
|
||||
enable_overlap=enable_overlap,
|
||||
@@ -637,19 +635,19 @@ class ScheduleBatch:
|
||||
return req_pool_indices
|
||||
|
||||
def alloc_token_slots(self, num_tokens: int):
|
||||
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||
|
||||
if out_cache_loc is None:
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
|
||||
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
|
||||
self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
|
||||
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||
|
||||
if out_cache_loc is None:
|
||||
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
|
||||
logger.error(
|
||||
f"{phase_str} out of memory. Try to lower your batch size.\n"
|
||||
f"Try to allocate {num_tokens} tokens.\n"
|
||||
f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||
)
|
||||
if self.tree_cache is not None:
|
||||
self.tree_cache.pretty_print()
|
||||
@@ -917,12 +915,12 @@ class ScheduleBatch:
|
||||
|
||||
def check_decode_mem(self, buf_multiplier=1):
|
||||
bs = len(self.reqs) * buf_multiplier
|
||||
if self.token_to_kv_pool.available_size() >= bs:
|
||||
if self.token_to_kv_pool_allocator.available_size() >= bs:
|
||||
return True
|
||||
|
||||
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
|
||||
self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)
|
||||
|
||||
if self.token_to_kv_pool.available_size() >= bs:
|
||||
if self.token_to_kv_pool_allocator.available_size() >= bs:
|
||||
return True
|
||||
|
||||
return False
|
||||
@@ -945,6 +943,10 @@ class ScheduleBatch:
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
|
||||
def get_required_tokens(num_reqs: int):
|
||||
headroom_for_spec_decode = 0
|
||||
if server_args.speculative_algorithm:
|
||||
@@ -958,18 +960,15 @@ class ScheduleBatch:
|
||||
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
||||
)
|
||||
|
||||
retracted_reqs = []
|
||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||
first_iter = True
|
||||
while (
|
||||
self.token_to_kv_pool.available_size()
|
||||
self.token_to_kv_pool_allocator.available_size()
|
||||
< get_required_tokens(len(sorted_indices))
|
||||
or first_iter
|
||||
):
|
||||
if len(sorted_indices) == 1:
|
||||
# Corner case: only one request left
|
||||
assert (
|
||||
self.token_to_kv_pool.available_size() > 0
|
||||
self.token_to_kv_pool_allocator.available_size() > 0
|
||||
), "No space left for only one request"
|
||||
break
|
||||
|
||||
@@ -983,7 +982,7 @@ class ScheduleBatch:
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, : seq_lens_cpu[idx]
|
||||
]
|
||||
self.token_to_kv_pool.free(token_indices)
|
||||
self.token_to_kv_pool_allocator.free(token_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
del self.tree_cache.entries[req.rid]
|
||||
else:
|
||||
@@ -992,7 +991,7 @@ class ScheduleBatch:
|
||||
token_indices = self.req_to_token_pool.req_to_token[
|
||||
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
||||
]
|
||||
self.token_to_kv_pool.free(token_indices)
|
||||
self.token_to_kv_pool_allocator.free(token_indices)
|
||||
self.req_to_token_pool.free(req.req_pool_idx)
|
||||
|
||||
# release the last node
|
||||
@@ -1001,10 +1000,13 @@ class ScheduleBatch:
|
||||
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
||||
residual_size = (
|
||||
len(sorted_indices) * global_config.retract_decode_steps
|
||||
- self.token_to_kv_pool.available_size()
|
||||
- self.token_to_kv_pool_allocator.available_size()
|
||||
)
|
||||
residual_size = max(0, residual_size)
|
||||
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
|
||||
self.tree_cache.evict(
|
||||
residual_size, self.token_to_kv_pool_allocator.free
|
||||
)
|
||||
|
||||
req.reset_for_retract()
|
||||
|
||||
self.filter_batch(keep_indices=sorted_indices)
|
||||
@@ -1183,7 +1185,7 @@ class ScheduleBatch:
|
||||
if self.spec_info:
|
||||
self.spec_info.merge_batch(other.spec_info)
|
||||
|
||||
def get_model_worker_batch(self):
|
||||
def get_model_worker_batch(self) -> ModelWorkerBatch:
|
||||
if self.forward_mode.is_decode_or_idle():
|
||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||
else:
|
||||
@@ -1273,7 +1275,7 @@ class ModelWorkerBatch:
|
||||
req_pool_indices: torch.Tensor
|
||||
# The sequence length
|
||||
seq_lens: torch.Tensor
|
||||
# The indices of output tokens in the token_to_kv_pool
|
||||
# The indices of output tokens in the token_to_kv_pool_allocator
|
||||
out_cache_loc: torch.Tensor
|
||||
|
||||
# The sum of all sequence lengths
|
||||
|
||||
Reference in New Issue
Block a user