diff --git a/python/pyproject.toml b/python/pyproject.toml index 5cb35b006..eaba878f9 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -102,6 +102,8 @@ dev = ["sglang[test]"] "srt/layers/moe/fused_moe_triton/configs/*/*.json", "srt/layers/quantization/configs/*.json", "srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp", + "srt/speculative/cpp_lookahead/*.cpp", + "srt/speculative/cpp_lookahead/*.h", ] [tool.setuptools.packages.find] diff --git a/python/sglang/bench_serving.py b/python/sglang/bench_serving.py index 401f702a4..ea670d97f 100644 --- a/python/sglang/bench_serving.py +++ b/python/sglang/bench_serving.py @@ -1110,7 +1110,8 @@ def sample_sharegpt_requests( add_generation_prompt=True, tokenize=False, ) - prompt = prompt.replace(tokenizer.bos_token, "") + if tokenizer.bos_token: + prompt = prompt.replace(tokenizer.bos_token, "") prompt_token_ids = tokenizer.encode(prompt) completion = dataset[i][1] diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index a5b207c77..107ac0cbd 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -29,6 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput +from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput from sglang.srt.utils import ( is_flashinfer_available, is_sm100_supported, @@ -317,7 +318,9 @@ class FlashInferAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[ + Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] + ], ): if forward_mode.is_decode_or_idle(): decode_wrappers = [] @@ -422,7 +425,9 @@ class FlashInferAttnBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[ + Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] + ], seq_lens_cpu: Optional[torch.Tensor], ): if forward_mode.is_decode_or_idle(): @@ -638,7 +643,9 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[ + Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] + ], ): # Keep the signature for type checking. It will be assigned during runtime. 
raise NotImplementedError() @@ -651,7 +658,9 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[ + Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] + ], ): decode_wrappers = decode_wrappers or self.decode_wrappers self.call_begin_forward( @@ -673,7 +682,9 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[ + Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] + ], ): assert self.sliding_window_size is not None for wrapper_id in range(2): @@ -721,7 +732,9 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[ + Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] + ], ): for wrapper_id in range(2): if wrapper_id == 0: @@ -753,7 +766,9 @@ class FlashInferIndicesUpdaterDecode: paged_kernel_lens_sum: int, kv_indptr: torch.Tensor, kv_start_idx: torch.Tensor, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[ + Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] + ], seq_lens_cpu: Optional[torch.Tensor], use_sliding_window_kv_pool: bool = False, ): @@ -858,7 +873,9 @@ class FlashInferIndicesUpdaterPrefill: prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[ + Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] + ], ): # Keep the signature for type checking. It will be assigned during runtime. 
raise NotImplementedError() @@ -873,7 +890,9 @@ class FlashInferIndicesUpdaterPrefill: prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[ + Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] + ], ): if use_ragged: # TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu @@ -909,7 +928,9 @@ class FlashInferIndicesUpdaterPrefill: prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[ + Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] + ], ): for wrapper_id in range(2): if wrapper_id == 0: @@ -955,7 +976,9 @@ class FlashInferIndicesUpdaterPrefill: prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[ + Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] + ], ): for wrapper_id in range(2): if wrapper_id == 0: @@ -997,7 +1020,9 @@ class FlashInferIndicesUpdaterPrefill: kv_indptr: torch.Tensor, qo_indptr: torch.Tensor, use_ragged: bool, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[ + Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] + ], use_sliding_window_kv_pool: bool = False, ): bs = len(seq_lens) @@ -1024,8 +1049,8 @@ class FlashInferIndicesUpdaterPrefill: qo_indptr = qo_indptr[: bs + 1] custom_mask = None else: - assert isinstance(spec_info, EagleDraftInput) or isinstance( - spec_info, EagleVerifyInput + assert isinstance( + spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput) ) kv_indices, kv_indptr, qo_indptr, custom_mask = ( spec_info.generate_attn_arg_prefill( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index c60864d64..9eaebad7a 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -74,6 +74,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton if TYPE_CHECKING: from sglang.srt.configs.model_config import ModelConfig from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput + from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 @@ -950,7 +951,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Speculative decoding spec_algorithm: SpeculativeAlgorithm = None - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None + spec_info: Optional[ + Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] + ] = None # Whether to return hidden states return_hidden_states: bool = False @@ -1600,7 +1603,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): self.forward_mode = ForwardMode.DECODE bs = len(self.reqs) - if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone(): + if ( + self.spec_algorithm.is_eagle() + or self.spec_algorithm.is_standalone() + or self.spec_algorithm.is_lookahead() + ): # if spec decoding is used, the decode batch is prepared inside # `forward_batch_speculative_generation` after running draft models. 
return @@ -1975,7 +1982,9 @@ class ModelWorkerBatch: # Speculative decoding spec_algorithm: SpeculativeAlgorithm = None - spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None + spec_info: Optional[ + Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput] + ] = None # If set, the output of the batch contains the hidden states of the run. capture_hidden_mode: CaptureHiddenMode = None hicache_consumer_index: int = -1 diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index a4f47819d..fc5525afa 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -385,6 +385,18 @@ class Scheduler( target_worker=self.tp_worker, dp_rank=dp_rank, ) + elif self.spec_algorithm.is_lookahead(): + from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker + + self.draft_worker = LOOKAHEADWorker( + gpu_id=gpu_id, + tp_rank=tp_rank, + moe_ep_rank=moe_ep_rank, + server_args=server_args, + nccl_port=port_args.nccl_port, + target_worker=self.tp_worker, + dp_rank=dp_rank, + ) else: self.draft_worker = None @@ -740,8 +752,8 @@ class Scheduler( else ( server_args.speculative_num_draft_tokens + ( - server_args.speculative_eagle_topk - * server_args.speculative_num_steps + (server_args.speculative_eagle_topk or 1) + * (server_args.speculative_num_steps or 1) ) ) ) @@ -784,7 +796,7 @@ class Scheduler( token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, draft_token_to_kv_pool=( None - if self.draft_worker is None + if self.draft_worker is None or self.spec_algorithm.is_lookahead() else self.draft_worker.model_runner.token_to_kv_pool ), req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator, @@ -821,7 +833,7 @@ class Scheduler( token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(), draft_token_to_kv_pool=( None - if self.draft_worker is None + if self.draft_worker is None or self.spec_algorithm.is_lookahead() else self.draft_worker.model_runner.token_to_kv_pool ), req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator, @@ -2358,9 +2370,8 @@ class Scheduler( self.req_to_token_pool.clear() self.token_to_kv_pool_allocator.clear() - if not self.spec_algorithm.is_none(): - self.draft_worker.model_runner.req_to_token_pool.clear() - self.draft_worker.model_runner.token_to_kv_pool_allocator.clear() + if self.draft_worker: + self.draft_worker.clear_cache_pool() self.num_generated_tokens = 0 self.forward_ct_decode = 0 diff --git a/python/sglang/srt/managers/tokenizer_manager.py b/python/sglang/srt/managers/tokenizer_manager.py index 6e9884dc3..09059277d 100644 --- a/python/sglang/srt/managers/tokenizer_manager.py +++ b/python/sglang/srt/managers/tokenizer_manager.py @@ -84,6 +84,7 @@ from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicat from sglang.srt.metrics.collector import TokenizerMetricsCollector from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.tracing.trace import ( trace_get_proc_propagate_context, trace_req_finish, @@ -174,6 +175,15 @@ class TokenizerManager(TokenizerCommunicatorMixin): self.image_token_id = self.model_config.image_token_id self.max_req_input_len = None # Will be set later in engine.py + speculative_algorithm = SpeculativeAlgorithm.from_string( + server_args.speculative_algorithm + ) + self.reserve_input_token_num = ( + 0 + if 
speculative_algorithm.is_none() + else server_args.speculative_num_draft_tokens + ) + if self.model_config.is_multimodal: import_processors() try: @@ -618,6 +628,7 @@ class TokenizerManager(TokenizerCommunicatorMixin): _max_req_len = self.context_len input_token_num = len(input_ids) if input_ids is not None else 0 + input_token_num += self.reserve_input_token_num if input_token_num >= self.context_len: if self.server_args.allow_auto_truncate: logger.warning( diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 8cab1b69f..460776995 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -275,6 +275,7 @@ class CudaGraphRunner: if ( model_runner.spec_algorithm.is_eagle() or model_runner.spec_algorithm.is_standalone() + or model_runner.spec_algorithm.is_lookahead() ): if self.model_runner.is_draft_worker: raise RuntimeError("This should not happen") @@ -441,11 +442,21 @@ class CudaGraphRunner: forward_batch.can_run_tbo if self.enable_two_batch_overlap else True ) + is_lookahead_supported = ( + ( + forward_batch.batch_size * self.num_tokens_per_bs + == forward_batch.input_ids.numel() + ) + if self.model_runner.spec_algorithm.is_lookahead() + else True + ) + return ( is_bs_supported and is_encoder_lens_supported and is_tbo_supported and capture_hidden_mode_matches + and is_lookahead_supported ) def capture(self) -> None: @@ -856,6 +867,20 @@ class CudaGraphRunner: seq_lens_cpu=None, ) + elif self.model_runner.spec_algorithm.is_lookahead(): + from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput + + spec_info = LookaheadVerifyInput( + draft_token=None, + tree_mask=self.custom_mask, + positions=None, + retrive_index=None, + retrive_next_token=None, + retrive_next_sibling=None, + draft_token_num=self.num_tokens_per_bs, + ) + spec_info.capture_hidden_mode = CaptureHiddenMode.NULL + return spec_info diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 1355353bb..0019fab76 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -1402,7 +1402,7 @@ class ModelRunner: if self.is_hybrid_gdn: max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size) - if not self.spec_algorithm.is_none(): + if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone(): if self.is_draft_worker: self.max_total_num_tokens = self.server_args.draft_runner_cache_size max_num_reqs = self.server_args.max_num_reqs diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index d55b56794..c1fb677b3 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -286,6 +286,14 @@ class ServerArgs: speculative_accept_threshold_acc: float = 1.0 speculative_token_map: Optional[str] = None speculative_attention_mode: str = "prefill" + # For lookahead only + speculative_lookahead_min_match_window_size: int = 1 + speculative_lookahead_max_match_window_size: int = 12 + speculative_lookahead_min_bfs_breadth: int = 1 + speculative_lookahead_max_bfs_breadth: int = 10 + speculative_lookahead_match_type: Literal["BFS", "PROB"] = "BFS" + speculative_lookahead_branch_length: int = 18 + speculative_lookahead_capacity: int = 10 * 1000 * 1000 # Expert parallelism ep_size: int = 1 @@ -529,7 +537,7 @@ class ServerArgs: # Standalone speculative decoding needs more memory than other 
speculative # decoding algorithms since the draft model is typically larger. reserved_mem += 6 * 1024 - else: + elif self.speculative_algorithm != "LOOKAHEAD": reserved_mem += 2 * 1024 if self.enable_dp_attention: reserved_mem += 4 * 1024 @@ -780,11 +788,11 @@ class ServerArgs: self.speculative_algorithm = "EAGLE" if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"): - if self.speculative_algorithm == "STANDALONE": + if self.speculative_algorithm == "STANDALONE" and self.enable_dp_attention: # TODO: support dp attention for standalone speculative decoding - assert ( - self.enable_dp_attention is False - ), "Currently standalone speculative decoding does not support dp attention." + raise ValueError( + "Currently standalone speculative decoding does not support dp attention." + ) if self.max_running_requests is None: self.max_running_requests = 48 self.disable_overlap_schedule = True @@ -858,6 +866,39 @@ class ServerArgs: # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded. # assert self.speculative_num_steps < self.speculative_num_draft_tokens + if self.speculative_algorithm == "LOOKAHEAD": + if not self.device.startswith("cuda"): + raise ValueError( + "Lookahead speculative decoding only supports CUDA device." + ) + if self.max_running_requests is None: + self.max_running_requests = 48 + self.disable_overlap_schedule = True + self.enable_mixed_chunk = False + self.speculative_eagle_topk = self.speculative_lookahead_max_bfs_breadth + if self.speculative_num_draft_tokens is None: + # TODO: Do better auto choose in the future + self.speculative_num_draft_tokens = ( + self.speculative_lookahead_max_match_window_size + ) + logger.warning( + "The overlap scheduler and mixed chunked prefill are disabled because of " + "using lookahead speculative decoding." + ) + if ( + self.speculative_eagle_topk > 1 + and self.page_size > 1 + and self.attention_backend != "flashinfer" + ): + raise ValueError( + "speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend." + ) + + if self.enable_dp_attention: + # TODO: support dp attention for lookahead speculative decoding + raise ValueError( + "Currently lookahead speculative decoding does not support dp attention." + ) # GGUF if ( self.load_format == "auto" or self.load_format == "gguf" @@ -1690,7 +1731,7 @@ class ServerArgs: parser.add_argument( "--speculative-algorithm", type=str, - choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE"], + choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "LOOKAHEAD"], help="Speculative algorithm.", ) parser.add_argument( @@ -1750,6 +1791,50 @@ class ServerArgs: help="Attention backend for speculative decoding operations (both target verify and draft extend). 
Can be one of 'prefill' (default) or 'decode'.",
         default=ServerArgs.speculative_attention_mode,
     )
+    # Lookahead speculative decoding
+    parser.add_argument(
+        "--speculative-lookahead-min-match-window-size",
+        type=int,
+        default=ServerArgs.speculative_lookahead_min_match_window_size,
+        help="The minimum window size for pattern matching in lookahead speculative decoding.",
+    )
+    parser.add_argument(
+        "--speculative-lookahead-max-match-window-size",
+        type=int,
+        default=ServerArgs.speculative_lookahead_max_match_window_size,
+        help="The maximum window size for pattern matching in lookahead speculative decoding.",
+    )
+    parser.add_argument(
+        "--speculative-lookahead-min-bfs-breadth",
+        type=int,
+        default=ServerArgs.speculative_lookahead_min_bfs_breadth,
+        help="The minimum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
+    )
+    parser.add_argument(
+        "--speculative-lookahead-max-bfs-breadth",
+        type=int,
+        default=ServerArgs.speculative_lookahead_max_bfs_breadth,
+        help="The maximum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
+    )
+    parser.add_argument(
+        "--speculative-lookahead-match-type",
+        type=str,
+        choices=["BFS", "PROB"],
+        default=ServerArgs.speculative_lookahead_match_type,
+        help="The match type for the lookahead cache tree ('BFS' or 'PROB').",
+    )
+    parser.add_argument(
+        "--speculative-lookahead-branch-length",
+        type=int,
+        default=ServerArgs.speculative_lookahead_branch_length,
+        help="The branch length for lookahead speculative decoding.",
+    )
+    parser.add_argument(
+        "--speculative-lookahead-capacity",
+        type=int,
+        default=ServerArgs.speculative_lookahead_capacity,
+        help="The cache capacity for lookahead speculative decoding.",
+    )
     # Expert parallelism
     parser.add_argument(
diff --git a/python/sglang/srt/speculative/cpp_lookahead/.clang-format b/python/sglang/srt/speculative/cpp_lookahead/.clang-format
new file mode 120000
index 000000000..5a7a8cea7
--- /dev/null
+++ b/python/sglang/srt/speculative/cpp_lookahead/.clang-format
@@ -0,0 +1 @@
+../../../../../sgl-kernel/.clang-format
\ No newline at end of file
diff --git a/python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp b/python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp
new file mode 100644
index 000000000..c47ebcd8d
--- /dev/null
+++ b/python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp
@@ -0,0 +1,372 @@
+#include "lookahead.h"
+
+#include <algorithm>
+#include <chrono>
+#include <cstring>
+#include <queue>
+
+namespace lookahead {
+
+struct Node {
+  std::unordered_map<int, int> next;  // token -> index of child in the flat tree
+};
+
+Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
+  Lookahead::Result info;
+  std::vector<int> prevs;
+  info.token.reserve(draft_token_num);
+  prevs.reserve(draft_token_num);
+  std::queue<std::tuple<int, int, int>> queue;
+  info.token.emplace_back(last_token);
+  prevs.emplace_back(-1);
+
+  for (auto [token, next] : tree[root].next) {
+    queue.emplace(token, next, 0);
+  }
+  while (queue.size()) {
+    auto [token, next, prev] = queue.front();
+    queue.pop();
+    info.token.emplace_back(token);
+    prevs.emplace_back(prev);
+    for (auto [t, n] : tree[next].next) {
+      queue.emplace(t, n, info.token.size() - 1);
+    }
+  }
+
+  // zero padding to length
+  while (info.token.size() < draft_token_num) {
+    info.token.emplace_back(0);
+    prevs.emplace_back(0);
+  }
+
+  int n = info.token.size();
+  info.mask.resize(n * n, 0);
+  info.mask[0] = 1;
+  for (int i = 0; i < n; ++i) {
+    if (prevs[i] != -1) {
+      // copy the ancestor prefix of the parent's row, then mark the diagonal
+      memcpy(&info.mask[i * n], &info.mask[prevs[i] * n], prevs[i] + 1);
+    }
+    info.mask[i * n + i] = 1;
+  }
+
+  return info;
+}
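+// Illustrative trace of fillResult (hypothetical tokens): with last_token = 5,
+// draft_token_num = 4 and a tree whose root has children {7 -> n1, 9 -> n2}
+// and n1 has child {8}, the BFS yields (assuming the map iterates 7 before 9)
+//   info.token = [5, 7, 9, 8]   prevs = [-1, 0, 0, 1]
+// and each mask row marks the path from the root to that token:
+//   row 0: 1 0 0 0   (5, the root, attends only to itself)
+//   row 1: 1 1 0 0   (5 -> 7)
+//   row 2: 1 0 1 0   (5 -> 9)
+//   row 3: 1 1 0 1   (5 -> 7 -> 8)
+// i.e. a flattened tree-attention mask for the verify pass.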
+Lookahead::Lookahead(size_t capacity, const Param& param) {
+  param_ = param;
+  nodes_.resize(capacity);
+  for (auto& node : nodes_) {
+    node_pool_.emplace_back(&node);
+  }
+  free_node_count_ = node_pool_.size();
+  root_ = getNode();
+
+  if (!(param_.branch_length > 1)) {
+    throw std::runtime_error(
+        "param_.branch_length must be greater than 1, current value: " + std::to_string(param_.branch_length));
+  }
+  if (!(param_.min_match_window_size > 0)) {
+    throw std::runtime_error(
+        "min_match_window_size must be greater than 0, current value: " + std::to_string(param_.min_match_window_size));
+  }
+  if (!(param_.min_match_window_size <= param_.max_match_window_size)) {
+    throw std::runtime_error(
+        "min_match_window_size must be less than or equal to max_match_window_size, current min_match_window_size: " +
+        std::to_string(param_.min_match_window_size) +
+        ", max_match_window_size: " + std::to_string(param_.max_match_window_size));
+  }
+  if (!(param_.max_match_window_size < param_.branch_length)) {
+    throw std::runtime_error(
+        "max_match_window_size must be less than branch_length, current max_match_window_size: " +
+        std::to_string(param_.max_match_window_size) + ", branch_length: " + std::to_string(param_.branch_length));
+  }
+  if (!(param_.min_bfs_breadth > 0)) {
+    throw std::runtime_error(
+        "min_bfs_breadth must be greater than 0, current value: " + std::to_string(param_.min_bfs_breadth));
+  }
+  if (!(param_.min_bfs_breadth <= param_.max_bfs_breadth)) {
+    throw std::runtime_error(
+        "min_bfs_breadth must be less than or equal to max_bfs_breadth, current min_bfs_breadth: " +
+        std::to_string(param_.min_bfs_breadth) + ", max_bfs_breadth: " + std::to_string(param_.max_bfs_breadth));
+  }
+  if (!(param_.draft_token_num > 0)) {
+    throw std::runtime_error(
+        "draft_token_num must be greater than 0, current value: " + std::to_string(param_.draft_token_num));
+  }
+  for (auto config : param_.batch_draft_token_num) {
+    if (config != std::numeric_limits<size_t>::max()) {
+      if (!(config <= param_.draft_token_num)) {
+        throw std::runtime_error(
+            "batch_draft_token_num config value " + std::to_string(config) +
+            " must be less than or equal to draft_token_num: " + std::to_string(param_.draft_token_num));
+      }
+    }
+  }
+  for (auto config : param_.batch_min_match_window_size) {
+    if (config != std::numeric_limits<size_t>::max()) {
+      if (!(config >= param_.min_match_window_size)) {
+        throw std::runtime_error(
+            "batch_min_match_window_size config value " + std::to_string(config) +
+            " must be greater than or equal to min_match_window_size: " + std::to_string(param_.min_match_window_size));
+      }
+      if (!(config <= param_.max_match_window_size)) {
+        throw std::runtime_error(
+            "batch_min_match_window_size config value " + std::to_string(config) +
+            " must be less than or equal to max_match_window_size: " + std::to_string(param_.max_match_window_size));
+      }
+    }
+  }
+
+  quit_flag_ = false;
+  insert_worker_ = std::thread(&Lookahead::insert, this);
+}
+
+Lookahead::~Lookahead() {
+  quit_flag_ = true;
+  insert_queue_.close();
+  insert_worker_.join();
+}
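+// Illustrative walk-through of the suffix matching below, with hypothetical
+// values: given min_match_window_size = 1, max_match_window_size = 3 and a
+// query ending in [..., a, b, c], the loop tries the suffixes [a, b, c],
+// [b, c], and [c] in that order, walking each from the trie root. Every
+// suffix fully contained in the trie contributes a (node, window_size) pair,
+// so the result is ordered longest-match-first and the longest context is
+// expanded first by matchBFS/matchProb, claiming the shared draft-token
+// budget before shorter, less specific matches.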
+std::vector<std::pair<TrieNode*, int32_t>>
+Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
+  auto draft_token_num = param_.get_draft_token_num(batch_size);
+  auto min_match_window_size = param_.get_min_match_window_size(batch_size);
+  auto max_match_window_size = param_.max_match_window_size;
+  std::vector<std::pair<TrieNode*, int32_t>> result;
+  result.reserve(param_.max_match_window_size - param_.min_match_window_size);
+  for (int32_t match_window_size = std::min(tokens.size(), param_.max_match_window_size);
+       match_window_size >= param_.min_match_window_size;
+       --match_window_size) {
+    auto start = tokens.data() + tokens.size() - match_window_size;
+    auto end = start + match_window_size;
+    auto cursor = root_;
+    while (start != end) {
+      auto iter = cursor->child.find(*start);
+      if (iter == cursor->child.end()) {
+        cursor = nullptr;
+        break;
+      }
+      ++start;
+      cursor = iter->second;
+    }
+    if (cursor) {
+      result.emplace_back(std::make_pair(cursor, match_window_size));
+    }
+  }
+  return result;
+}
+
+void Lookahead::squeeze(size_t count) {
+  if (!(node_pool_.size() >= free_node_count_ + count)) {
+    throw std::runtime_error(
+        "Insufficient node size to release required nodes. "
+        "available to release: " +
+        std::to_string(node_pool_.size() - free_node_count_) + ", required to release: " + std::to_string(count));
+  }
+  while (count--) {
+    auto last = global_lru_.back();
+    global_lru_.pop_back();
+
+    if (!last->child.empty()) {
+      throw std::runtime_error("The node to be released still has child nodes and cannot be released. ");
+    }
+
+    last->parent->lru.erase(last->parent_lru_pos);
+    last->parent->sorted_children.erase(last);
+    last->parent->child.erase(last->token);
+
+    node_pool_[free_node_count_++] = last;
+  }
+}
+
+void Lookahead::synchronize() const {
+  while (!insert_queue_.empty()) {
+    std::this_thread::sleep_for(std::chrono::microseconds(10));
+  }
+}
+
+void Lookahead::insert() {
+  while (!quit_flag_) {
+    std::vector<int32_t> data;
+    if (!insert_queue_.dequeue(data)) {
+      continue;
+    }
+    const auto* token = data.data();
+    size_t size = data.size();
+    std::unique_lock<std::mutex> lock(mutex_);
+
+    for (size_t i = 0; i + param_.min_match_window_size < size; ++i) {
+      auto start = token + i;
+      auto end = start + std::min(size - i, param_.branch_length);
+
+      if (end - start > free_node_count_) {
+        squeeze(end - start - free_node_count_);
+      }
+
+      TrieNode* cursor = root_;
+      path_.clear();
+      while (start != end) {
+        auto token = *start;
+        auto iter = cursor->child.find(token);
+        if (iter == cursor->child.end()) {
+          iter = cursor->child.insert({token, getNode()}).first;
+          auto node = iter->second;
+
+          cursor->lru.emplace_front(node);
+          global_lru_.emplace_back(node);
+
+          node->token = token;
+          node->parent = cursor;
+          node->parent_lru_pos = cursor->lru.begin();
+          node->global_lru_pos = --global_lru_.end();
+          node->freq = 1;
+          cursor->sorted_children.insert(node);
+        } else {
+          auto node = iter->second;
+          cursor->sorted_children.erase(node);
+          node->freq++;
+          cursor->sorted_children.insert(node);
+          cursor->lru.splice(cursor->lru.begin(), cursor->lru, node->parent_lru_pos);
+        }
+        cursor = iter->second;
+        path_.emplace_back(cursor);
+        ++start;
+      }
+
+      for (auto it = path_.rbegin(); it != path_.rend(); ++it) {
+        TrieNode* node = *it;
+        global_lru_.splice(global_lru_.begin(), global_lru_, node->global_lru_pos);
+      }
+    }
+  }
+}
+
+void Lookahead::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
+  for (auto&& token : tokens) {
+    insert_queue_.enqueue(std::move(token));
+  }
+}
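+// Worked example for the breadth interpolation used in matchBFS, assuming the
+// server defaults (match window 1..12, BFS breadth 1..10):
+//   bfs_breadth_scale = (10 - 1) / (12 - 1 + 1) = 0.75
+// A match of depth 12 starts with breadth (12 - 12) * 0.75 + 1 = 1, while a
+// match of depth 4 starts with (12 - 4) * 0.75 + 1 = 7; each level deeper in
+// the draft tree then shrinks the breadth by another 0.75 (clamped to >= 1).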
+Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
+  std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
+
+  double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /
+                             (param_.max_match_window_size - param_.min_match_window_size + 1);
+
+  auto draft_token_num = param_.get_draft_token_num(batch_size);
+  std::vector<Node> tree(draft_token_num + 1);
+  int root = 0;
+  int cursor = 1;
+
+  for (auto [node, depth] : nodes) {
+    std::queue<std::tuple<int, double, const TrieNode*>> queue;  // parent, bfs_breadth, node
+    queue.push({root, (param_.max_match_window_size - depth) * bfs_breadth_scale + param_.min_bfs_breadth, node});
+    while (queue.size() && cursor <= draft_token_num) {
+      auto front = queue.front();
+      queue.pop();
+
+      auto parent = std::get<0>(front);
+      auto cur_breadth = std::get<1>(front);
+      auto iter = std::get<2>(front)->lru.begin();
+
+      auto breadth = std::max(1, int32_t(cur_breadth));
+      for (int i = 0; i < breadth && iter != std::get<2>(front)->lru.end() && cursor <= draft_token_num; ++i, ++iter) {
+        auto token = (*iter)->token;
+        auto pos = -1;
+        if (auto tit = tree[parent].next.find(token); tit != tree[parent].next.end()) {
+          pos = tit->second;
+        } else {
+          pos = tree[parent].next.insert(std::make_pair(token, cursor++)).first->second;
+        }
+        queue.emplace(pos, cur_breadth - bfs_breadth_scale, *iter);
+      }
+    }
+  }
+
+  return fillResult(tokens.back(), draft_token_num + 1, tree, root);
+}
+
+Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
+  std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
+  auto draft_token_num = param_.get_draft_token_num(batch_size);
+
+  struct CompareByLastDouble {
+    bool operator()(
+        const std::tuple<int, const TrieNode*, double>& a,  // parent_pos, node, final_prob
+        const std::tuple<int, const TrieNode*, double>& b) const {
+      return std::get<2>(a) < std::get<2>(b);
+    }
+  };
+
+  std::priority_queue<
+      std::tuple<int, const TrieNode*, double>,
+      std::vector<std::tuple<int, const TrieNode*, double>>,
+      CompareByLastDouble>
+      heap;
+
+  std::vector<Node> tree(draft_token_num + 1);
+
+  int root = 0;
+  int cursor = 1;
+  int top_k = param_.max_bfs_breadth;
+
+  auto addToHeap = [&heap, &top_k](int parent, const TrieNode* trie_node, double prob) -> void {
+    double sum_freq = 0.0;
+    int count = 0;
+    std::list<std::pair<const TrieNode*, int32_t>> topk_children;
+    for (auto* child : trie_node->sorted_children) {
+      sum_freq += static_cast<double>(child->freq);
+      topk_children.emplace_back(child, child->freq);
+      if (++count >= top_k) break;
+    }
+    if (sum_freq <= 0) sum_freq = 1.0;
+    for (const auto& [child, freq] : topk_children) {
+      double norm_freq = static_cast<double>(freq) / sum_freq * prob;
+      heap.emplace(parent, child, norm_freq);
+    }
+  };
+
+  for (auto [node, _] : nodes) {
+    addToHeap(root, node, 1.0);
+
+    while (!heap.empty() && cursor <= draft_token_num) {
+      auto [parent, trie_node, prob] = heap.top();  // parent_pos, node, final_prob
+      heap.pop();
+      auto token = trie_node->token;
+      int pos = -1;
+      auto tit = tree[parent].next.find(token);
+      if (tit != tree[parent].next.end()) {
+        pos = tit->second;
+      } else {
+        pos = cursor++;
+        tree[parent].next[token] = pos;
+      }
+      addToHeap(pos, trie_node, prob);
+    }
+  }
+
+  return fillResult(tokens.back(), draft_token_num + 1, tree, root);
+}
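+// Illustrative matchProb expansion (hypothetical frequencies): a node with
+// children of freq 6 and 2 pushes them with normalized probabilities
+// 6/8 = 0.75 and 2/8 = 0.25. If the 0.75 child itself has two children of
+// freq 1 each, they enter the heap with 0.75 * 0.5 = 0.375, so a grandchild
+// on a frequent path is drafted before the 0.25 sibling. Unlike matchBFS,
+// the expansion order is global over the heap rather than level by level.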
+Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
+  std::unique_lock<std::mutex> lock(mutex_);
+  Result merged_result;
+  auto match_func = param_.match_type == "BFS" ? &Lookahead::matchBFS : &Lookahead::matchProb;
+  for (const auto& tks : tokens) {
+    Result res = (this->*match_func)(tks, tokens.size());
+    merged_result.token.insert(merged_result.token.end(), res.token.begin(), res.token.end());
+    merged_result.mask.insert(merged_result.mask.end(), res.mask.begin(), res.mask.end());
+  }
+  return merged_result;
+}
+
+void Lookahead::Result::truncate(size_t n) {
+  if (n < token.size()) {
+    int full_n = token.size();
+    for (int i = 1; i < n; ++i) {
+      // memmove: the compacted row can overlap its source row when n > full_n / 2
+      memmove(&mask[i * n], &mask[i * full_n], sizeof(mask[0]) * n);
+    }
+    token.resize(n);
+    mask.resize(n * n);
+  }
+}
+
+}  // namespace lookahead
diff --git a/python/sglang/srt/speculative/cpp_lookahead/lookahead.h b/python/sglang/srt/speculative/cpp_lookahead/lookahead.h
new file mode 100644
index 000000000..9c6c82c92
--- /dev/null
+++ b/python/sglang/srt/speculative/cpp_lookahead/lookahead.h
@@ -0,0 +1,110 @@
+#pragma once
+
+#include <cstdint>
+#include <list>
+#include <memory>
+#include <mutex>
+#include <new>
+#include <set>
+#include <string>
+#include <thread>
+#include <tuple>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "param.h"
+#include "queue.h"
+
+namespace lookahead {
+
+struct TrieNode {
+  std::unordered_map<int32_t, TrieNode*> child;
+  std::list<TrieNode*>::const_iterator global_lru_pos;
+  std::list<TrieNode*>::const_iterator parent_lru_pos;
+  int32_t token;
+  TrieNode* parent;
+  std::list<TrieNode*> lru;
+  int32_t freq = 0;
+
+  struct CompareByFreq {
+    bool operator()(TrieNode* a, TrieNode* b) const {
+      return std::tie(b->freq, a->token, a) < std::tie(a->freq, b->token, b);
+    }
+  };
+  std::multiset<TrieNode*, CompareByFreq> sorted_children;
+};
+
+class Lookahead {
+  std::vector<TrieNode> nodes_;
+  std::vector<TrieNode*> node_pool_;
+  size_t free_node_count_;
+  std::list<TrieNode*> global_lru_;
+  TrieNode* root_;
+  std::vector<TrieNode*> path_;
+  Param param_;
+
+  std::vector<std::pair<TrieNode*, int32_t>> match(const std::vector<int32_t>& tokens, size_t batch_size) const;
+
+  void squeeze(size_t count);
+
+  TrieNode* getNode() {
+    auto node = node_pool_[--free_node_count_];
+    node->~TrieNode();
+    new (node) TrieNode();
+    return node;
+  }
+
+  mutable std::mutex mutex_;
+  bool quit_flag_;
+  utils::Queue<std::vector<int32_t>> insert_queue_;
+  std::thread insert_worker_;
+  std::vector<std::vector<int32_t>> match_tmp_data_;
+
+ public:
+  Lookahead(size_t capacity, const Param& param);
+  Lookahead() = default;
+  ~Lookahead();
+
+  static Lookahead& instance() {
+    static Lookahead instance;
+    return instance;
+  }
+
+  void synchronize() const;
+
+  void asyncInsert(std::vector<std::vector<int32_t>>&& tokens);
+
+  struct Result {
+    std::vector<int32_t> token;
+    // Mask entries are byte-sized flags; fillResult and truncate() copy row
+    // prefixes with memcpy/memmove using byte counts (element width assumed here).
+    std::vector<uint8_t> mask;
+
+    void truncate(size_t n);
+  };
+
+  Result batchMatch(const std::vector<std::vector<int32_t>>& tokens) const;
+
+  void reset() {
+    std::unique_lock<std::mutex> lock(mutex_);
+
+    global_lru_.clear();
+    path_.clear();
+    node_pool_.clear();
+    for (auto& node : nodes_) {
+      node_pool_.emplace_back(&node);
+    }
+    free_node_count_ = node_pool_.size();
+    root_ = getNode();
+  }
+
+  const Param& param() const {
+    return param_;
+  }
+
+ private:
+  Result matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const;
+  Result matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const;
+
+  void insert();
+};
+
+}  // namespace lookahead
diff --git a/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache.py b/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache.py
new file mode 100644
index 000000000..871b60878
--- /dev/null
+++ b/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache.py
@@ -0,0 +1,140 @@
+# -*- coding: utf-8 -*-
+
+# from sglang.op.lookahead import Lookahead, Param
+
+import logging
+import os
+from typing import List, Tuple
+
+import numpy as np
+from torch.utils.cpp_extension import load
+
+logger = logging.getLogger(__name__)
+
+_abs_path = 
os.path.dirname(os.path.abspath(__file__))
+lookahead_cache_cpp = load(
+    name="lookahead_cache_cpp",
+    sources=[
+        f"{_abs_path}/lookahead_cache_binding.cpp",
+        f"{_abs_path}/lookahead.cpp",
+    ],
+    extra_cflags=["-O3", "-std=c++20"],
+)
+
+
+class LookaheadCache:
+    def __init__(
+        self,
+        branch_length=18,
+        min_match_window_size=1,
+        max_match_window_size=10,
+        min_bfs_breadth=1,
+        max_bfs_breadth=8,
+        draft_token_num=8,
+        match_type="BFS",
+        capacity=1000000,
+    ):
+        param = lookahead_cache_cpp.Param()
+        param.branch_length = branch_length
+        param.min_match_window_size = min_match_window_size
+        param.max_match_window_size = max_match_window_size
+        param.min_bfs_breadth = min_bfs_breadth
+        param.max_bfs_breadth = max_bfs_breadth
+        param.draft_token_num = draft_token_num
+        param.match_type = match_type
+        self.cache = lookahead_cache_cpp.Lookahead(capacity, param)
+
+        self.default_mask = np.ones((1, 1), dtype=np.int64)
+        self.draft_token_num = draft_token_num
+
+    def batch_put(self, batch_tokens: List[List[int]]):
+        self.cache.asyncInsert(batch_tokens)
+
+    def synchronize(self):
+        self.cache.synchronize()
+
+    def reset(self):
+        self.cache.reset()
+
+    def batch_get(self, batch_tokens: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
+        result = self.cache.batchMatch(batch_tokens)
+        return np.array(result.token), np.array(result.mask)
+
+    def leaf_paths_from_mask(
+        self, tokens: List[int], tree_mask: List[List[int]]
+    ) -> List[List[int]]:
+        """
+        Find all leaf paths according to the binary tree_mask (i.e., paths that are not prefixes of any other path).
+
+        Args:
+            tokens : List[int]           # token list corresponding to columns
+            tree_mask : List[List[int]]  # n x n binary matrix
+
+        Returns:
+            List[List[int]]  # token lists of only the leaf paths, preserving their order of appearance
+        """
+
+        row_sets = [
+            (i, {idx for idx, v in enumerate(row) if v == 1})
+            for i, row in enumerate(tree_mask)
+        ]
+        leaf_sets = []
+        leaf_rows = []
+
+        for i, cur_set in reversed(row_sets):
+            if any(cur_set <= kept for kept in leaf_sets):
+                continue
+            leaf_sets.append(cur_set)
+            leaf_rows.append(i)
+
+        leaf_rows.reverse()
+        result = []
+        for r in leaf_rows:
+            path = [tokens[col] for col in range(len(tokens)) if tree_mask[r][col] == 1]
+            result.append(path)
+
+        return result
+
+    def debug_result(
+        self, decoding_ids: np.ndarray, decoding_masks: np.ndarray, tokenizer=None
+    ):
+        decoding_ids = decoding_ids.reshape(-1, self.draft_token_num)
+        decoding_masks = decoding_masks.reshape(
+            -1, self.draft_token_num, self.draft_token_num
+        )
+        logger.info(f"\n{decoding_ids=}\n{decoding_masks=}")
+        for i in range(decoding_ids.shape[0]):
+            leaf_paths = self.leaf_paths_from_mask(
+                decoding_ids[i].tolist(), decoding_masks[i].tolist()
+            )
+            if tokenizer is None:
+                logger.info(f"draft path {i}: {leaf_paths}")
+            else:
+                logger.info(f"result {i}:")
+                for leaf_path in leaf_paths:
+                    logger.info(
+                        f"draft path {i}: {leaf_path} -> {tokenizer.decode(leaf_path, ensure_ascii=False)}"
+                    )
+
+
+# main function
+if __name__ == "__main__":
+    log_format = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
+    logging.basicConfig(
+        level=logging.DEBUG,
+        format=log_format,
+        datefmt="%Y-%m-%d %H:%M:%S",
+        force=True,
+    )
+
+    token_ids = [
+        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
+        [1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
+    ]
+    cache = LookaheadCache(branch_length=12, draft_token_num=8)
+    cache.batch_put(token_ids)
+
+    cache.synchronize()
+    decoding_ids, decoding_masks = cache.batch_get([[1, 2, 3], [3, 44], [3, 6, 999]])
+
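+    # batch_get returns flat arrays: with draft_token_num=8 and three queries,
+    # decoding_ids holds 3 * 8 entries (the first entry of each group is the
+    # query's last input token) and decoding_masks holds 3 * 8 * 8 entries;
+    # debug_result below reshapes them to (3, 8) and (3, 8, 8).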
+    cache.debug_result(decoding_ids, decoding_masks)
diff --git a/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache_binding.cpp b/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache_binding.cpp
new file mode 100644
index 000000000..8c48a66ae
--- /dev/null
+++ b/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache_binding.cpp
@@ -0,0 +1,43 @@
+#include <pybind11/pybind11.h>
+#include <pybind11/stl.h>
+
+#include "lookahead.h"
+
+PYBIND11_MODULE(lookahead_cache_cpp, m) {
+  using namespace lookahead;
+  namespace py = pybind11;
+  m.doc() = "";
+
+  py::class_<Lookahead>(m, "Lookahead")
+      .def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"))
+      .def("asyncInsert", &Lookahead::asyncInsert, "")
+      .def("batchMatch", &Lookahead::batchMatch, "")
+      .def("reset", &Lookahead::reset, "")
+      .def("synchronize", &Lookahead::synchronize, "");
+
+  py::class_<Param>(m, "Param")
+      .def(py::init<>())
+      .def_readwrite("enable", &Param::enable)
+      .def_readwrite("enable_router_mode", &Param::enable_router_mode)
+      .def_readwrite("min_bfs_breadth", &Param::min_bfs_breadth)
+      .def_readwrite("max_bfs_breadth", &Param::max_bfs_breadth)
+      .def_readwrite("min_match_window_size", &Param::min_match_window_size)
+      .def_readwrite("max_match_window_size", &Param::max_match_window_size)
+      .def_readwrite("branch_length", &Param::branch_length)
+      .def_readwrite("draft_token_num", &Param::draft_token_num)
+      .def_readwrite("match_type", &Param::match_type)
+      .def_readwrite("batch_min_match_window_size", &Param::batch_min_match_window_size)
+      .def_readwrite("batch_draft_token_num", &Param::batch_draft_token_num)
+      .def("get_draft_token_num", &Param::get_draft_token_num, "")
+      .def("get_min_match_window_size", &Param::get_min_match_window_size, "")
+      .def("parse", &Param::parse, "")
+      .def("resetBatchMinMatchWindowSize", &Param::resetBatchMinMatchWindowSize, "")
+      .def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
+      .def("detail", &Param::detail, "");
+
+  py::class_<Lookahead::Result>(m, "Result")
+      .def(py::init<>())
+      .def_readwrite("token", &Lookahead::Result::token)
+      .def_readwrite("mask", &Lookahead::Result::mask)
+      .def("truncate", &Lookahead::Result::truncate);
+}
diff --git a/python/sglang/srt/speculative/cpp_lookahead/param.h b/python/sglang/srt/speculative/cpp_lookahead/param.h
new file mode 100644
index 000000000..2d8b1f875
--- /dev/null
+++ b/python/sglang/srt/speculative/cpp_lookahead/param.h
@@ -0,0 +1,125 @@
+#pragma once
+
+#include <cstdlib>
+#include <iostream>
+#include <limits>
+#include <regex>
+#include <sstream>
+#include <stdexcept>
+#include <string>
+#include <vector>
+
+namespace lookahead {
+
+struct Param {
+  bool enable;
+  bool enable_router_mode;
+  size_t min_bfs_breadth;
+  size_t max_bfs_breadth;
+  size_t min_match_window_size;
+  size_t max_match_window_size;
+  size_t branch_length;
+  size_t draft_token_num;
+  std::string match_type;
+
+  std::vector<size_t> batch_min_match_window_size;
+  std::vector<size_t> batch_draft_token_num;
+
+  size_t get_draft_token_num(size_t batch_size) const {
+    if (batch_size < batch_draft_token_num.size()) {
+      if (batch_draft_token_num[batch_size] != std::numeric_limits<size_t>::max()) {
+        return batch_draft_token_num[batch_size];
+      }
+    }
+    return draft_token_num - 1;
+  }
+
+  size_t get_min_match_window_size(size_t batch_size) const {
+    if (batch_size < batch_min_match_window_size.size()) {
+      if (batch_min_match_window_size[batch_size] != std::numeric_limits<size_t>::max()) {
+        return batch_min_match_window_size[batch_size];
+      }
+    }
+    return min_match_window_size;
+  }
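+  // parse() decodes a per-batch-size override string of the form
+  // "L-R|value,L-R|value,...". For example, "0-1|10,2-3|20" yields a vector
+  // where positions 0..1 hold 10, positions 2..3 hold 20, and all other slots
+  // keep the numeric_limits<size_t>::max() sentinel, which the getters above
+  // interpret as "fall back to the global default".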
+  std::vector<size_t> parse(const std::string& value) {
+    // e.g. "0-1|10,2-3|20"
+    std::vector<size_t> result;
+    if (value.empty()) {
+      return result;
+    }
+    std::vector<bool> mark;
+    std::regex comma_re(",");
+    std::sregex_token_iterator first{value.begin(), value.end(), comma_re, -1}, last;
+    for (const auto& seg : std::vector<std::string>(first, last)) {
+      std::regex pipe_re("\\|");
+      std::sregex_token_iterator seg_first{seg.begin(), seg.end(), pipe_re, -1}, seg_last;
+      std::vector<std::string> part(seg_first, seg_last);
+      if (part.size() != 2) {
+        throw std::runtime_error(
+            "failed to get config, invalid config: " + seg + ", part's size = " + std::to_string(part.size()));
+      }
+      std::regex endash_re("-");
+      std::sregex_token_iterator range_first{part[0].begin(), part[0].end(), endash_re, -1}, range_last;
+      std::vector<std::string> range(range_first, range_last);
+      if (range.size() != 2) {
+        throw std::runtime_error("failed to get range, invalid config: " + value);
+      }
+      size_t L = std::atoi(range[0].c_str());
+      size_t R = std::atoi(range[1].c_str());
+      if (L > R || R > 128) {
+        throw std::runtime_error("invalid range, config: " + value);
+      }
+      if (R >= result.size()) {
+        result.resize(R + 1, std::numeric_limits<size_t>::max());
+        mark.resize(result.size(), false);
+      }
+      size_t config = std::atoi(part[1].c_str());
+      do {
+        if (mark[L]) {
+          throw std::runtime_error("repeated position " + std::to_string(L) + ", config : " + value);
+        }
+        mark[L] = true;
+        result[L] = config;
+      } while (++L <= R);
+    }
+    return result;
+  }
+
+  void resetBatchMinMatchWindowSize(const std::string& value) {
+    batch_min_match_window_size = parse(value);
+  }
+
+  void resetBatchReturnTokenNum(const std::string& value) {
+    batch_draft_token_num = parse(value);
+  }
+
+  std::string detail() {
+    std::stringstream ss;
+    ss << "enable = " << enable << ", enable_router_mode = " << enable_router_mode
+       << ", min_bfs_breadth = " << min_bfs_breadth << ", max_bfs_breadth = " << max_bfs_breadth
+       << ", min_match_window_size = " << min_match_window_size << ", max_match_window_size = " << max_match_window_size
+       << ", branch_length = " << branch_length << ", draft_token_num = " << draft_token_num
+       << ", match_type = " << match_type;
+    ss << ", batch_min_match_window_size(" << batch_min_match_window_size.size() << ") = ";
+    for (int i = 0; i < batch_min_match_window_size.size(); ++i) {
+      ss << i << "|" << batch_min_match_window_size[i] << ",";
+    }
+    ss << ", batch_draft_token_num(" << batch_draft_token_num.size() << ") = ";
+    for (int i = 0; i < batch_draft_token_num.size(); ++i) {
+      ss << i << "|" << batch_draft_token_num[i] << ",";
+    }
+    return ss.str();
+  }
+};
+
+}  // namespace lookahead
diff --git a/python/sglang/srt/speculative/cpp_lookahead/queue.h b/python/sglang/srt/speculative/cpp_lookahead/queue.h
new file mode 100644
index 000000000..e84a0fa7b
--- /dev/null
+++ b/python/sglang/srt/speculative/cpp_lookahead/queue.h
@@ -0,0 +1,71 @@
+#pragma once
+
+#include <condition_variable>
+#include <mutex>
+#include <queue>
+
+namespace utils {
+
+template <typename T>
+class Queue {
+ public:
+  bool enqueue(T&& rhs) {
+    {
+      std::lock_guard<std::mutex> lock(mutex_);
+      if (closed_) {
+        return false;
+      }
+      queue_.emplace(std::move(rhs));
+    }
+    cv_.notify_one();
+    return true;
+  }
+
+  bool enqueue(const T& rhs) {
+    {
+      std::lock_guard<std::mutex> lock(mutex_);
+      if (closed_) {
+        return false;
+      }
+      queue_.emplace(rhs);
+    }
+    cv_.notify_one();
+    return true;
+  }
+
+  bool dequeue(T& rhs) {
+    std::unique_lock<std::mutex> lock(mutex_);
+    cv_.wait(lock, [this] { return queue_.size() || closed_; });
+    if (closed_) {
+      return false;
+    }
+    rhs = std::move(queue_.front());
+
queue_.pop(); + return true; + } + + size_t size() const { + std::lock_guard lock(mutex_); + return queue_.size(); + } + + bool empty() const { + std::lock_guard lock(mutex_); + return queue_.empty(); + } + + void close() { + { + std::lock_guard lock(mutex_); + closed_ = true; + } + cv_.notify_all(); + } + + private: + std::queue queue_; + mutable std::mutex mutex_; + std::condition_variable cv_; + bool closed_{false}; +}; + +} // namespace utils diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index d3adec7b7..abc13da9d 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -771,6 +771,10 @@ class EAGLEWorker(TpModelWorker): return score_list, token_list, parents_list + def clear_cache_pool(self): + self.model_runner.req_to_token_pool.clear() + self.model_runner.token_to_kv_pool_allocator.clear() + def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput): spec_info.prepare_for_verify(batch, self.page_size) batch.return_hidden_states = False diff --git a/python/sglang/srt/speculative/lookahead_utils.py b/python/sglang/srt/speculative/lookahead_utils.py new file mode 100644 index 000000000..5ca6cb025 --- /dev/null +++ b/python/sglang/srt/speculative/lookahead_utils.py @@ -0,0 +1,412 @@ +from __future__ import annotations + +import copy +import logging +from typing import Optional + +import torch +import triton + +logger = logging.getLogger(__name__) + +from dataclasses import dataclass + +import torch.nn.functional as F + +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.sampler import apply_custom_logit_processor +from sglang.srt.managers.schedule_batch import ( + ScheduleBatch, + get_last_loc, + global_server_args_dict, +) +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo +from sglang.srt.speculative.eagle_utils import ( + TREE_SPEC_KERNEL_AVAILABLE, + assign_req_to_token_pool, + create_flashinfer_kv_indices_triton, + get_src_tgt_cache_loc, + get_target_cache_loc, +) +from sglang.srt.utils import is_cuda, is_hip, next_power_of_2 + +if is_cuda(): + from sgl_kernel import ( + top_k_renorm_prob, + top_p_renorm_prob, + tree_speculative_sampling_target_only, + verify_tree_greedy, + ) +elif is_hip(): + from sgl_kernel import verify_tree_greedy + + +@dataclass +class LookaheadVerifyInput: + def __init__( + self, + draft_token: torch.Tensor, + tree_mask: torch.Tensor, + positions: torch.Tensor, + retrive_index: torch.Tensor, + retrive_next_token: torch.Tensor, + retrive_next_sibling: torch.Tensor, + draft_token_num: int, + ): + self.draft_token = draft_token + self.custom_mask = tree_mask + self.positions = positions + self.retrive_index = retrive_index + self.retrive_next_token = retrive_next_token + self.retrive_next_sibling = retrive_next_sibling + self.draft_token_num = draft_token_num + self.device = self.custom_mask.device + + def prepare_for_verify(self, batch: ScheduleBatch, page_size: int): + if batch.forward_mode.is_idle(): + return + + batch.input_ids = self.draft_token + + if page_size == 1: + batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids)) + end_offset = batch.seq_lens + self.draft_token_num + else: + prefix_lens = batch.seq_lens + end_offset = prefix_lens + self.draft_token_num + last_loc = get_last_loc( + batch.req_to_token_pool.req_to_token, + batch.req_pool_indices, + prefix_lens, + ) + batch.out_cache_loc = batch.alloc_paged_token_slots_extend( + prefix_lens, 
end_offset, last_loc, len(batch.input_ids)
+            )
+            self.last_loc = last_loc
+
+        bs = batch.batch_size()
+        assign_req_to_token_pool[(bs,)](
+            batch.req_pool_indices,
+            batch.req_to_token_pool.req_to_token,
+            batch.seq_lens,
+            end_offset,
+            batch.out_cache_loc,
+            batch.req_to_token_pool.req_to_token.shape[1],
+            triton.next_power_of_2(bs),
+        )
+
+    def generate_attn_arg_prefill(
+        self,
+        req_pool_indices: torch.Tensor,
+        paged_kernel_lens: torch.Tensor,
+        paged_kernel_lens_sum: int,
+        req_to_token: torch.Tensor,
+    ):
+        bs = len(req_pool_indices)
+
+        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+
+        paged_kernel_lens = paged_kernel_lens + self.draft_token_num
+        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
+
+        self.qo_indptr = (
+            torch.arange(0, bs + 1, dtype=torch.int32, device=self.device)
+            * self.draft_token_num
+        )
+
+        kv_indices = torch.empty(
+            cum_kv_seq_len[-1], dtype=torch.int32, device=self.device
+        )
+
+        create_flashinfer_kv_indices_triton[(bs,)](
+            req_to_token,
+            req_pool_indices,
+            paged_kernel_lens,
+            cum_kv_seq_len,
+            None,
+            kv_indices,
+            req_to_token.size(1),
+        )
+        return kv_indices, cum_kv_seq_len, self.qo_indptr, self.custom_mask
+
+    def _fill_requests(
+        self,
+        batch: ScheduleBatch,
+        logits_output: LogitsProcessorOutput,
+    ):
+        accept_index_cpu = self.accept_index.tolist()
+        predict_cpu = self.predict.tolist()
+        has_finished = False
+
+        # Iterate over every accepted token and check whether the request has
+        # finished after appending it. This must happen BEFORE freeing KV cache slots.
+        for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
+            for j, idx in enumerate(accept_index_row):
+                if idx == -1:
+                    break
+                id = predict_cpu[idx]
+                req.output_ids.append(id)
+                req.check_finished()
+                if req.finished():
+                    has_finished = True
+                    # set all tokens after the finished token to -1 and break
+                    self.accept_index[i, j + 1 :] = -1
+                    break
+                else:
+                    if req.grammar is not None:
+                        try:
+                            req.grammar.accept_token(id)
+                        except ValueError as e:
+                            logger.info(
+                                f"{i=}, {req=}\n"
+                                f"{self.accept_index=}\n"
+                                f"{self.predict=}\n"
+                            )
+                            raise e
+            req.spec_verify_ct += 1
+        if has_finished:
+            self.accept_length = (self.accept_index != -1).sum(dim=1) - 1
+        self.accept_index = self.accept_index[self.accept_index != -1]
+
+        logits_output.next_token_logits = logits_output.next_token_logits[
+            self.accept_index
+        ]
+        if logits_output.hidden_states is not None:
+            logits_output.hidden_states = logits_output.hidden_states[self.accept_index]
+        self.verified_id = self.predict[self.accept_index]
+
+    def _free_cache(self, batch: ScheduleBatch, page_size: int):
+        bs = batch.batch_size()
+        # Free the KV cache for unaccepted tokens
+        if page_size == 1:
+            # TODO: boolean array index leads to a device sync. Remove it.
+            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
+            evict_mask[self.accept_index] = False
+            batch.token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
+            batch.out_cache_loc = batch.out_cache_loc[self.accept_index]
+        else:
+            # Shift the accepted tokens to the beginning.
+ # Only evict the last part + src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc( + batch.seq_lens, + batch.out_cache_loc, + self.accept_index, + self.accept_length, + self.draft_token_num, + page_size, + ) + to_free_slots = torch.empty( + (to_free_num_slots.sum().item(),), + dtype=torch.int64, + device=to_free_num_slots.device, + ) + + # out_cache_loc: [0 1 2, 3 4 5, 6 7 8] + # accept_index: [0 -1 2, 3 4 -1, 6 -1 -1] + # tgt_cache_loc: [0 1 , 3 4 , 6 ] + # to_free_slots: [ 2, 5, 7 8] + # to_free_slots also needs to be page-aligned without the first partial page + # + # split each row of out_cache_loc into two parts. + # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1 + # 2. the second part goes to to_free_slots. + get_target_cache_loc[(bs,)]( + tgt_cache_loc, + to_free_slots, + self.accept_length, + to_free_num_slots, + batch.out_cache_loc, + self.draft_token_num, + next_power_of_2(self.draft_token_num), + next_power_of_2(bs), + ) + + # Free the kv cache + batch.token_to_kv_pool_allocator.free(to_free_slots) + + # Copy the kv cache + batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache( + tgt_cache_loc, src_cache_loc + ) + batch.out_cache_loc = tgt_cache_loc + + assign_req_to_token_pool[(bs,)]( + batch.req_pool_indices, + batch.req_to_token_pool.req_to_token, + batch.seq_lens, + batch.seq_lens + self.accept_length + 1, + batch.out_cache_loc, + batch.req_to_token_pool.req_to_token.shape[1], + triton.next_power_of_2(bs), + ) + + def _greedy_verify( + self, + batch: ScheduleBatch, + logits_output: LogitsProcessorOutput, + ): + bs = batch.batch_size() + target_predict = torch.argmax(logits_output.next_token_logits, dim=-1) + target_predict = target_predict.reshape(bs, self.draft_token_num) + + candidates = self.draft_token.reshape(bs, self.draft_token_num) + predict_shape = list(logits_output.next_token_logits.shape)[:-1] + predict_shape[-1] += 1 + self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device) + self.accept_index = torch.full( + (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device + ) + self.accept_length = torch.empty((bs,), dtype=torch.int32, device=self.device) + + verify_tree_greedy( + predicts=self.predict, # mutable + accept_index=self.accept_index, # mutable + accept_token_num=self.accept_length, # mutable + candidates=candidates, + retrive_index=self.retrive_index, + retrive_next_token=self.retrive_next_token, + retrive_next_sibling=self.retrive_next_sibling, + target_predict=target_predict, + ) + + def _sampling_verify( + self, + batch: ScheduleBatch, + logits_output: LogitsProcessorOutput, + sampling_info: SamplingBatchInfo, + ): + bs = batch.batch_size() + candidates = self.draft_token.reshape(bs, self.draft_token_num) + predict_shape = list(logits_output.next_token_logits.shape)[:-1] + predict_shape[-1] += 1 + self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device) + self.accept_index = torch.full( + (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device + ) + self.accept_length = torch.empty((bs,), dtype=torch.int32, device=self.device) + # apply temperature and get target probs + expanded_temperature = torch.repeat_interleave( + sampling_info.temperatures, self.draft_token_num, dim=0 + ) # (bs * draft_token_num, 1) + + target_probs = F.softmax( + logits_output.next_token_logits / expanded_temperature, dim=-1 + ) # (bs * draft_token_num, vocab_size) + + # NOTE: The test shows that top_p_renorm_prob and top_k_renorm_prob are the key 
factors + # contributing to the poor performance of _sampling_verify. + target_probs = top_k_renorm_prob( + target_probs, + torch.repeat_interleave(sampling_info.top_ks, self.draft_token_num, dim=0), + ) # (bs * draft_token_num, vocab_size) + + if sampling_info.need_top_p_sampling: + # logger.info("Using top-p sampling in speculative decoding verification.") + target_probs = top_p_renorm_prob( + target_probs, + torch.repeat_interleave( + sampling_info.top_ps, self.draft_token_num, dim=0 + ), + ) + + target_probs = target_probs.reshape(bs, self.draft_token_num, -1) + draft_probs = torch.zeros( + target_probs.shape, dtype=torch.float32, device=self.device + ) + + # coins for rejection sampling + coins = torch.rand_like(candidates, dtype=torch.float32, device=self.device) + # coins for final sampling + coins_for_final_sampling = torch.rand( + (bs,), dtype=torch.float32, device=self.device + ) + tree_speculative_sampling_target_only( + predicts=self.predict, # mutable + accept_index=self.accept_index, # mutable + accept_token_num=self.accept_length, # mutable + candidates=candidates.to(torch.int64), + retrive_index=self.retrive_index.to(torch.int64), + retrive_next_token=self.retrive_next_token.to(torch.int64), + retrive_next_sibling=self.retrive_next_sibling.to(torch.int64), + uniform_samples=coins, + uniform_samples_for_final_sampling=coins_for_final_sampling, + target_probs=target_probs, + draft_probs=draft_probs, + threshold_single=global_server_args_dict[ + "speculative_accept_threshold_single" + ], + threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"], + deterministic=True, + ) + + def verify( + self, + batch: ScheduleBatch, + logits_output: LogitsProcessorOutput, + page_size: int, + vocab_mask: Optional[torch.Tensor] = None, # For grammar + ) -> torch.Tensor: + bs = self.retrive_index.shape[0] + sampling_info = batch.sampling_info + + if bs != len(sampling_info): + sampling_info = copy.deepcopy(sampling_info) + # NOTE: retrive_index are the indices of the requests that are kept. + sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index) + + # Apply the custom logit processors if registered in the sampling info. + if sampling_info.has_custom_logit_processor: + apply_custom_logit_processor( + logits_output.next_token_logits, + sampling_info, + num_tokens_in_batch=self.draft_token_num, + ) + + # Apply penalty + if sampling_info.penalizer_orchestrator.is_required: + # This is a relaxed version of penalties for speculative decoding. + linear_penalty = torch.zeros( + (bs, logits_output.next_token_logits.shape[1]), + dtype=torch.float32, + device=self.device, + ) + sampling_info.apply_logits_bias(linear_penalty) + logits_output.next_token_logits.add_( + torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0) + ) + + # Apply grammar mask + if vocab_mask is not None: + assert self.grammar is not None + self.grammar.apply_vocab_mask( + logits=logits_output.next_token_logits, vocab_mask=vocab_mask + ) + + # Sample tokens. Force greedy sampling on AMD + is_all_greedy = sampling_info.is_all_greedy + if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE): + logger.warning( + "Tree speculative sampling kernel unavailable (likely AMD/HIP build). " + "Falling back to greedy verification." + ) + + if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE: + self._greedy_verify(batch, logits_output) + else: + # NOTE: Compared with greedy_verify, the performance of _sampling_verify is relatively poor. 
+        tree_speculative_sampling_target_only(
+            predicts=self.predict,  # mutable
+            accept_index=self.accept_index,  # mutable
+            accept_token_num=self.accept_length,  # mutable
+            candidates=candidates.to(torch.int64),
+            retrive_index=self.retrive_index.to(torch.int64),
+            retrive_next_token=self.retrive_next_token.to(torch.int64),
+            retrive_next_sibling=self.retrive_next_sibling.to(torch.int64),
+            uniform_samples=coins,
+            uniform_samples_for_final_sampling=coins_for_final_sampling,
+            target_probs=target_probs,
+            draft_probs=draft_probs,
+            threshold_single=global_server_args_dict[
+                "speculative_accept_threshold_single"
+            ],
+            threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"],
+            deterministic=True,
+        )
+
+    def verify(
+        self,
+        batch: ScheduleBatch,
+        logits_output: LogitsProcessorOutput,
+        page_size: int,
+        vocab_mask: Optional[torch.Tensor] = None,  # For grammar
+    ) -> tuple[LogitsProcessorOutput, torch.Tensor, int]:
+        bs = self.retrive_index.shape[0]
+        sampling_info = batch.sampling_info
+
+        if bs != len(sampling_info):
+            sampling_info = copy.deepcopy(sampling_info)
+            # NOTE: retrive_index holds the indices of the requests that are kept.
+            sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)
+
+        # Apply the custom logit processors if registered in the sampling info.
+        if sampling_info.has_custom_logit_processor:
+            apply_custom_logit_processor(
+                logits_output.next_token_logits,
+                sampling_info,
+                num_tokens_in_batch=self.draft_token_num,
+            )
+
+        # Apply penalties
+        if sampling_info.penalizer_orchestrator.is_required:
+            # This is a relaxed version of penalties for speculative decoding.
+            linear_penalty = torch.zeros(
+                (bs, logits_output.next_token_logits.shape[1]),
+                dtype=torch.float32,
+                device=self.device,
+            )
+            sampling_info.apply_logits_bias(linear_penalty)
+            logits_output.next_token_logits.add_(
+                torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
+            )
+
+        # Apply the grammar mask
+        if vocab_mask is not None:
+            assert self.grammar is not None
+            self.grammar.apply_vocab_mask(
+                logits=logits_output.next_token_logits, vocab_mask=vocab_mask
+            )
+
+        # Sample tokens. Force greedy sampling if the tree kernel is unavailable.
+        is_all_greedy = sampling_info.is_all_greedy
+        if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
+            logger.warning(
+                "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
+                "Falling back to greedy verification."
+            )
+
+        if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
+            self._greedy_verify(batch, logits_output)
+        else:
+            # NOTE: _sampling_verify is noticeably slower than _greedy_verify, so
+            # greedy verification is used for now; the sampling path is kept for reference.
+            self._greedy_verify(batch, logits_output)
+            # self._sampling_verify(batch, logits_output, sampling_info)
+
+        self._fill_requests(batch, logits_output)
+        self._free_cache(batch, page_size)
+
+        batch.seq_lens.add_(self.accept_length + 1)
+        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
+
+        return logits_output, self.verified_id, self.accept_length.sum().item()
+
+    def filter_batch(self, new_indices: torch.Tensor):
+        pass
+
+    def merge_batch(self, spec_info: LookaheadVerifyInput):
+        pass
diff --git a/python/sglang/srt/speculative/lookahead_worker.py b/python/sglang/srt/speculative/lookahead_worker.py
new file mode 100644
index 000000000..040078ac7
--- /dev/null
+++ b/python/sglang/srt/speculative/lookahead_worker.py
@@ -0,0 +1,244 @@
+import logging
+import os
+import threading
+import time
+from typing import TYPE_CHECKING, List, Optional, Union
+
+import numpy as np
+import torch
+from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
+
+from sglang.srt.managers.schedule_batch import ScheduleBatch
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.cpp_lookahead.lookahead_cache import LookaheadCache
+from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import broadcast_pyobj
+
+logger = logging.getLogger(__name__)
+
+USE_FULL_MASK = True
+
+
+class LOOKAHEADWorker:
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        moe_ep_rank: int,
+        nccl_port: int,
+        target_worker: TpModelWorker,
+    ):
+        self.target_worker = target_worker
+        self.model_runner = target_worker.model_runner
+        self.tp_rank = tp_rank
+        self.page_size = server_args.page_size
+        self.draft_token_num: int = server_args.speculative_num_draft_tokens
+        self.branch_length: int = server_args.speculative_lookahead_branch_length
+        self.max_match_window_size: int = (
+            server_args.speculative_lookahead_max_match_window_size
+        )
+
+        self.max_batch_size = target_worker.max_running_requests
+        self.device = f"cuda:{gpu_id}" if gpu_id >= 0 else "cuda"
+
+        self._init_preallocated_tensors()
+
+        self.lookahead_cache = LookaheadCache(
+            min_match_window_size=server_args.speculative_lookahead_min_match_window_size,
+            max_match_window_size=server_args.speculative_lookahead_max_match_window_size,
+            min_bfs_breadth=server_args.speculative_lookahead_min_bfs_breadth,
+            max_bfs_breadth=server_args.speculative_lookahead_max_bfs_breadth,
+            capacity=server_args.speculative_lookahead_capacity,
+            branch_length=server_args.speculative_lookahead_branch_length,
+            draft_token_num=server_args.speculative_num_draft_tokens,
+        )
+
+    def clear_cache_pool(self):
+        self.lookahead_cache.reset()
+
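+    # Take the last n tokens of seq1 + seq2 without materializing the full
+    # concatenation, e.g. seq1=[1, 2, 3, 4], seq2=[5, 6], n=3 -> [4, 5, 6].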
+    def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int):
+        seq2_len = len(seq2)
+        if seq2_len >= n:
+            return seq2[-n:]
+
+        need_from_seq1 = n - seq2_len
+        return seq1[-need_from_seq1:] + seq2
+
+    def _init_preallocated_tensors(self):
+        max_total_drafts = self.max_batch_size * self.draft_token_num
+        max_total_mask_size = (
+            self.max_batch_size * self.draft_token_num * self.draft_token_num
+        )
+
+        self.draft_tokens = torch.empty(
+            (max_total_drafts,), dtype=torch.int64, device=self.device
+        )
+        self.retrieve_indexes = torch.empty(
+            (self.max_batch_size, self.draft_token_num),
+            dtype=torch.int64,
+            device=self.device,
+        )
+        self.retrive_next_token = torch.empty(
+            (self.max_batch_size, self.draft_token_num),
+            dtype=torch.int64,
+            device=self.device,
+        )
+        self.retrive_next_sibling = torch.empty(
+            (self.max_batch_size, self.draft_token_num),
+            dtype=torch.int64,
+            device=self.device,
+        )
+        self.positions = torch.empty(
+            (max_total_drafts,), dtype=torch.int64, device=self.device
+        )
+        self.tree_mask = torch.empty(
+            (max_total_mask_size,), dtype=torch.bool, device=self.device
+        )
+
+        self.draft_tokens_batch = []
+        self.tree_mask_batch = []
+        self.retrieve_indexes_batch = []
+        self.retrive_next_token_batch = []
+        self.retrive_next_sibling_batch = []
+        self.positions_batch = []
+
+        for bs in range(0, self.max_batch_size + 1):
+            self.retrieve_indexes_batch.append(self.retrieve_indexes[:bs, :])
+            self.retrive_next_token_batch.append(self.retrive_next_token[:bs, :])
+            self.retrive_next_sibling_batch.append(self.retrive_next_sibling[:bs, :])
+            self.positions_batch.append(self.positions[: bs * self.draft_token_num])
+            self.draft_tokens_batch.append(
+                self.draft_tokens[: bs * self.draft_token_num]
+            )
+            self.tree_mask_batch.append(
+                self.tree_mask[: bs * self.draft_token_num * self.draft_token_num]
+            )
+
+    def _prepare_draft_tokens(
+        self, batch: ScheduleBatch
+    ) -> tuple[np.ndarray, np.ndarray]:
+        bs = batch.batch_size()
+
+        self.lookahead_cache.synchronize()
+        batch_tokens = []
+        for req in batch.reqs:
+            check_token = self._efficient_concat_last_n(
+                req.origin_input_ids, req.output_ids, self.max_match_window_size
+            )
+            batch_tokens.append(check_token)
+        req_drafts, mask = self.lookahead_cache.batch_get(batch_tokens)
+        total_draft_token_num = len(req_drafts)
+
+        # Check that the cache produced a full draft set; here we always enforce it.
+        assert (
+            total_draft_token_num == bs * self.draft_token_num
+        ), f"{total_draft_token_num=}, {bs=}, {self.draft_token_num=}"
+        return req_drafts, mask
+
+    def _prepare_for_speculative_decoding(self, batch: ScheduleBatch):
+        if batch.forward_mode.is_extend():
+            return
+
+        bs = batch.batch_size()
+
+        retrive_index = self.retrieve_indexes_batch[bs]
+        retrive_next_token = self.retrive_next_token_batch[bs]
+        retrive_next_sibling = self.retrive_next_sibling_batch[bs]
+        positions = self.positions_batch[bs]
+        tree_mask = self.tree_mask_batch[bs]
+        draft_tokens = self.draft_tokens_batch[bs]
+
+        req_drafts, mask = self._prepare_draft_tokens(batch)
+        tree_mask.copy_(torch.from_numpy(mask), non_blocking=True)
+        draft_tokens.copy_(torch.from_numpy(req_drafts), non_blocking=True)
+
+        reconstruct_indices_from_tree_mask(
+            tree_mask,
+            batch.seq_lens,
+            positions,  # mutable
+            retrive_index,  # mutable
+            retrive_next_token,  # mutable
+            retrive_next_sibling,  # mutable
+            bs,
+            self.draft_token_num,
+        )
+
+        # NOTE: A QLEN-style mask is faster than the full mask but requires
+        # corresponding changes in flashinfer. Testing shows about an 8% performance
+        # improvement (the effect is roughly proportional to batch size).
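+        # With USE_FULL_MASK, each request's mask covers its entire KV history:
+        # the first (seq_len - 1) columns are all ones (every draft token attends
+        # to the verified prefix) and the last draft_token_num columns hold the
+        # tree mask, e.g. for draft_token_num=2 and seq_len=4:
+        #   [1 1 1 | 1 0]
+        #   [1 1 1 | 1 1]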
+        if USE_FULL_MASK:
+            tree_mask = []
+            mask = mask.reshape(
+                batch.batch_size(), self.draft_token_num, self.draft_token_num
+            )
+            for i, req in enumerate(batch.reqs):
+                seq_len = len(req.origin_input_ids) + len(req.output_ids)
+                req_mask = torch.ones((self.draft_token_num, seq_len - 1)).cuda()
+                req_mask = torch.cat(
+                    (req_mask, torch.from_numpy(mask[i]).cuda()), dim=1
+                ).to(torch.bool)
+                tree_mask.append(req_mask.flatten())
+            tree_mask = torch.cat(tree_mask, dim=0)
+
+        batch.spec_algorithm = SpeculativeAlgorithm.LOOKAHEAD
+        batch.forward_mode = ForwardMode.TARGET_VERIFY
+        batch.spec_info = LookaheadVerifyInput(
+            draft_tokens,
+            tree_mask,
+            positions,
+            retrive_index,
+            retrive_next_token,
+            retrive_next_sibling,
+            self.draft_token_num,
+        )
+        batch.spec_info.prepare_for_verify(batch, self.page_size)
+
+    def _update_lookahead_cache(self, batch: ScheduleBatch):
+        batch_tokens = []
+        for req in batch.reqs:
+            # FIXME: Testing showed little difference between inserting the 'extend'
+            # tokens into the cache and leaving them out, so we skip them for now.
+            # if batch.forward_mode.is_extend():
+            #     put_ids = req.origin_input_ids + req.output_ids
+            # else:
+            put_ids = self._efficient_concat_last_n(
+                req.origin_input_ids, req.output_ids, self.branch_length
+            )
+            batch_tokens.append(put_ids)
+        self.lookahead_cache.batch_put(batch_tokens)
+
+    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
+        self._prepare_for_speculative_decoding(batch)
+        model_worker_batch = batch.get_model_worker_batch()
+        bid = model_worker_batch.bid
+        num_accepted_tokens = 0
+
+        if model_worker_batch.forward_mode.is_target_verify():
+            logits_output, _, can_run_cuda_graph = (
+                self.target_worker.forward_batch_generation(
+                    model_worker_batch, skip_sample=True
+                )
+            )
+            verify_input = model_worker_batch.spec_info
+            logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
+                batch, logits_output, self.page_size
+            )
+            self._update_lookahead_cache(batch)
+            batch.forward_mode = ForwardMode.DECODE
+
+        else:
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                self.target_worker.forward_batch_generation(model_worker_batch)
+            )
+
+        return (
+            logits_output,
+            next_token_ids,
+            bid,
+            num_accepted_tokens,
+            can_run_cuda_graph,
+        )
diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py
index a80963471..a865d0ff6 100644
--- a/python/sglang/srt/speculative/spec_info.py
+++ b/python/sglang/srt/speculative/spec_info.py
@@ -6,6 +6,7 @@ class SpeculativeAlgorithm(IntEnum):
     EAGLE = auto()
     EAGLE3 = auto()
     STANDALONE = auto()
+    LOOKAHEAD = auto()
 
     def is_none(self):
         return self == SpeculativeAlgorithm.NONE
@@ -19,12 +20,16 @@
     def is_standalone(self):
         return self == SpeculativeAlgorithm.STANDALONE
 
+    def is_lookahead(self):
+        return self == SpeculativeAlgorithm.LOOKAHEAD
+
     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
             "EAGLE3": SpeculativeAlgorithm.EAGLE3,
             "STANDALONE": SpeculativeAlgorithm.STANDALONE,
+            "LOOKAHEAD": SpeculativeAlgorithm.LOOKAHEAD,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py
index 6a69fed00..208b45578 100644
--- a/python/sglang/test/test_utils.py
+++ b/python/sglang/test/test_utils.py
@@ -80,6 +80,7 @@ DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
     "meta-llama/Llama-3.1-8B-Instruct"
 )
 DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
"meta-llama/Llama-3.2-1B-Instruct" +DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct" # Other use cases DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = ( diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 3a1c64605..9b73c048a 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -318,6 +318,7 @@ set(SOURCES "csrc/kvcacheio/transfer.cu" "csrc/speculative/eagle_utils.cu" + "csrc/speculative/lookahead_utils.cu" "csrc/speculative/packbit.cu" "csrc/speculative/speculative_sampling.cu" diff --git a/sgl-kernel/csrc/common_extension.cc b/sgl-kernel/csrc/common_extension.cc index 9c89bcad7..21f3763f6 100644 --- a/sgl-kernel/csrc/common_extension.cc +++ b/sgl-kernel/csrc/common_extension.cc @@ -291,6 +291,12 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { "Tensor target_predict, int cuda_stream) -> ()"); m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy); + m.def( + "reconstruct_indices_from_tree_mask(Tensor tree_mask, Tensor verified_seq_len, Tensor positions, " + "Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, " + "int batch_size, int draft_token_num) -> ()"); + m.impl("reconstruct_indices_from_tree_mask", torch::kCUDA, &reconstruct_indices_from_tree_mask); + m.def( "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, " diff --git a/sgl-kernel/csrc/speculative/lookahead_utils.cu b/sgl-kernel/csrc/speculative/lookahead_utils.cu new file mode 100644 index 000000000..b51054222 --- /dev/null +++ b/sgl-kernel/csrc/speculative/lookahead_utils.cu @@ -0,0 +1,105 @@ +#include +#include + +#ifndef USE_ROCM +#include "pytorch_extension_utils.h" +#else +#include "pytorch_extension_utils_rocm.h" +#endif + +// tree_mask: [bs * draft_token_num * draft_token_num] +// verified_seq_len: [bs] +// positions: [bs * draft_token_num] +// retrive_index: [bs, draft_token_num] +// retrive_next_token: [bs, draft_token_num] +// retrive_next_sibling: [bs, draft_token_num] +__global__ void reconstructIndicesFromTreeMask( + bool* tree_mask, + int64_t* verified_seq_len, + int64_t* positions, + int64_t* retrive_index, + int64_t* retrive_next_token, + int64_t* retrive_next_sibling, + int batch_size, + int draft_token_num) { + int bid = blockIdx.x; + int tid = threadIdx.x; + + if (bid >= batch_size || tid >= draft_token_num) { + return; + } + int base_offset = draft_token_num * draft_token_num; + // token_idx: [bid * draft_token_num, (bid + 1) * draft_token_num) + int token_idx = bid * draft_token_num; + // tree_mask_idx: [bid * base_offset, (bid + 1) * base_offset) + int tree_mask_offset = bid * base_offset; + + int depth = 0; + int parent_idx = -1; + + for (int i = tid - 1, start_idx = tree_mask_offset + tid * draft_token_num; i >= 0; i--) { + if (tree_mask[start_idx + i]) { + depth++; + if (parent_idx == -1) { + parent_idx = i; + } + } + } + retrive_index[token_idx + tid] = token_idx + tid; + positions[token_idx + tid] = depth + verified_seq_len[bid]; + + int next_token_idx = -1; + for (int i = tid + 1; i < draft_token_num; i++) { + if (tree_mask[tree_mask_offset + i * draft_token_num + tid]) { + next_token_idx = i; + break; + } + } + retrive_next_token[token_idx + tid] = next_token_idx; + + int next_sibling_idx = -1; + if (parent_idx != -1) { + for (int i = tid + 1; i < draft_token_num; i++) { + int start_idx = tree_mask_offset + i * draft_token_num + parent_idx; + if 
+  int next_sibling_idx = -1;
+  if (parent_idx != -1) {
+    for (int i = tid + 1; i < draft_token_num; i++) {
+      int start_idx = tree_mask_offset + i * draft_token_num + parent_idx;
+      if (tree_mask[start_idx]) {
+        bool is_sibling = true;
+        int end_idx = tree_mask_offset + i * draft_token_num + i;
+        for (int j = start_idx + 1; j < end_idx; ++j) {
+          if (tree_mask[j]) {
+            is_sibling = false;
+            break;
+          }
+        }
+        if (is_sibling) {
+          next_sibling_idx = i;
+          break;
+        }
+      }
+    }
+  }
+  retrive_next_sibling[token_idx + tid] = next_sibling_idx;
+}
+
+void reconstruct_indices_from_tree_mask(
+    at::Tensor tree_mask,
+    at::Tensor verified_seq_len,
+    at::Tensor positions,
+    at::Tensor retrive_index,
+    at::Tensor retrive_next_token,
+    at::Tensor retrive_next_sibling,
+    int64_t batch_size,
+    int64_t draft_token_num) {
+  dim3 grid(batch_size);
+  dim3 block(draft_token_num);
+  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+  reconstructIndicesFromTreeMask<<<grid, block, 0, stream>>>(
+      static_cast<bool*>(tree_mask.data_ptr()),
+      static_cast<int64_t*>(verified_seq_len.data_ptr()),
+      static_cast<int64_t*>(positions.data_ptr()),
+      static_cast<int64_t*>(retrive_index.data_ptr()),
+      static_cast<int64_t*>(retrive_next_token.data_ptr()),
+      static_cast<int64_t*>(retrive_next_sibling.data_ptr()),
+      int(batch_size),
+      int(draft_token_num));
+}
diff --git a/sgl-kernel/include/sgl_kernel_ops.h b/sgl-kernel/include/sgl_kernel_ops.h
index c166319c0..5829a72e4 100644
--- a/sgl-kernel/include/sgl_kernel_ops.h
+++ b/sgl-kernel/include/sgl_kernel_ops.h
@@ -457,6 +457,16 @@ void verify_tree_greedy(
     at::Tensor target_predict,
     int64_t cuda_stream = 0);
 
+void reconstruct_indices_from_tree_mask(
+    at::Tensor tree_mask,
+    at::Tensor verified_seq_len,
+    at::Tensor positions,          // mutable
+    at::Tensor retrive_index,      // mutable
+    at::Tensor retrive_next_token, // mutable
+    at::Tensor retrive_next_sibling, // mutable
+    int64_t batch_size,
+    int64_t draft_token_num);
+
 void build_tree_kernel_efficient(
     at::Tensor parent_list,
     at::Tensor selected_index,
diff --git a/sgl-kernel/python/sgl_kernel/__init__.py b/sgl-kernel/python/sgl_kernel/__init__.py
index 37ba4e3a9..e6f8c0dc6 100644
--- a/sgl-kernel/python/sgl_kernel/__init__.py
+++ b/sgl-kernel/python/sgl_kernel/__init__.py
@@ -126,6 +126,7 @@ from sgl_kernel.sampling import (
 )
 from sgl_kernel.speculative import (
     build_tree_kernel_efficient,
+    reconstruct_indices_from_tree_mask,
     segment_packbits,
     tree_speculative_sampling_target_only,
     verify_tree_greedy,
diff --git a/sgl-kernel/python/sgl_kernel/speculative.py b/sgl-kernel/python/sgl_kernel/speculative.py
index ea2e3ac8a..ee4bf7fb0 100644
--- a/sgl-kernel/python/sgl_kernel/speculative.py
+++ b/sgl-kernel/python/sgl_kernel/speculative.py
@@ -90,6 +90,28 @@ def build_tree_kernel_efficient(
     )
 
 
+def reconstruct_indices_from_tree_mask(
+    tree_mask: torch.Tensor,
+    verified_seq_len: torch.Tensor,
+    positions: torch.Tensor,
+    retrive_index: torch.Tensor,
+    retrive_next_token: torch.Tensor,
+    retrive_next_sibling: torch.Tensor,
+    batch_size: int,
+    draft_token_num: int,
+) -> None:
+    torch.ops.sgl_kernel.reconstruct_indices_from_tree_mask.default(
+        tree_mask,
+        verified_seq_len,
+        positions,
+        retrive_index,
+        retrive_next_token,
+        retrive_next_sibling,
+        batch_size,
+        draft_token_num,
+    )
+
+
 def segment_packbits(
     x: torch.Tensor,
     input_indptr: torch.Tensor,
diff --git a/sgl-kernel/tests/speculative/test_lookahead_utils.py b/sgl-kernel/tests/speculative/test_lookahead_utils.py
new file mode 100644
index 000000000..29bf89f93
--- /dev/null
+++ b/sgl-kernel/tests/speculative/test_lookahead_utils.py
@@ -0,0 +1,76 @@
+import pytest
+import torch
+from sgl_kernel import reconstruct_indices_from_tree_mask
+
+
+def test_reconstruct_indices_from_tree_mask():
+    bs = 1
+    num_branch_token = 4
+    seq_lens = torch.tensor([12], device="cuda", dtype=torch.int64)
+
+    retrive_index = torch.full(
+        (bs, num_branch_token), -1, device="cuda", dtype=torch.int64
+    )
+    retrive_next_token = torch.full(
+        (bs, num_branch_token), -1, device="cuda", dtype=torch.int64
+    )
+    retrive_next_sibling = torch.full(
+        (bs, num_branch_token), -1, device="cuda", dtype=torch.int64
+    )
+    positions = torch.empty((bs * num_branch_token), device="cuda", dtype=torch.int64)
+
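+    # The mask below encodes this draft tree: token 0 is the root with
+    # children 1 and 2, and 3 is a child of 2, so next_token(0)=1,
+    # next_sibling(1)=2, next_token(2)=3, and depths are 0, 1, 1, 2.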
+    # fmt: off
+    tree_mask = torch.tensor(
+        [
+            1, 0, 0, 0,
+            1, 1, 0, 0,
+            1, 0, 1, 0,
+            1, 0, 1, 1,
+        ],
+        device="cuda",
+        dtype=torch.int32,
+    ).to(torch.bool)
+    # fmt: on
+
+    reconstruct_indices_from_tree_mask(
+        tree_mask,
+        seq_lens,
+        positions,  # mutable
+        retrive_index,  # mutable
+        retrive_next_token,  # mutable
+        retrive_next_sibling,  # mutable
+        bs,
+        num_branch_token,
+    )
+    assert retrive_index.tolist() == [
+        [0, 1, 2, 3],
+    ], f"{retrive_index=}"
+    assert retrive_next_token.tolist() == [
+        [1, -1, 3, -1],
+    ], f"{retrive_next_token=}"
+    assert retrive_next_sibling.tolist() == [
+        [-1, 2, -1, -1],
+    ], f"{retrive_next_sibling=}"
+    assert positions.tolist() == [12, 13, 13, 14], f"{positions=}"
+
+
+if __name__ == "__main__":
+    test_reconstruct_indices_from_tree_mask()
+    pytest.main([__file__])
diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py
index 13d1000a3..a5c449d00 100644
--- a/test/srt/run_suite.py
+++ b/test/srt/run_suite.py
@@ -78,6 +78,7 @@ suites = {
         TestFile("test_hidden_states.py", 55),
         TestFile("test_hybrid_attn_backend.py", 100),
         TestFile("test_standalone_speculative_decoding.py", 250),
+        TestFile("test_lookahead_speculative_decoding.py", 250),
         TestFile("test_input_embeddings.py", 38),
         TestFile("test_io_struct.py", 8),
         TestFile("test_jinja_template_utils.py", 1),
diff --git a/test/srt/test_lookahead_speculative_decoding.py b/test/srt/test_lookahead_speculative_decoding.py
new file mode 100644
index 000000000..b0e7da529
--- /dev/null
+++ b/test/srt/test_lookahead_speculative_decoding.py
@@ -0,0 +1,107 @@
+import os
+import unittest
+from types import SimpleNamespace
+
+import requests
+
+from sglang.srt.utils import kill_process_tree
+from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
+from sglang.test.test_utils import (
+    DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST,
+    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+    DEFAULT_URL_FOR_TEST,
+    CustomTestCase,
+    popen_launch_server,
+)
+
+GSM_DATASET_PATH = None
+
+
+# Default server arguments shared across all tests
+DEFAULT_SERVER_ARGS = [
+    "--trust-remote-code",
+    "--cuda-graph-max-bs",
+    "8",
+    "--speculative-algorithm",
+    "LOOKAHEAD",
+    "--speculative-num-draft-tokens",
+    "16",
+    "--mem-fraction-static",
+    "0.8",
+]
+
+
+class TestLookaheadSpeculativeDecodingBase(CustomTestCase):
+
+    model = DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST
+    base_url = DEFAULT_URL_FOR_TEST
+    accuracy_threshold = 0.8  # derived tests need to override this
+    spec_decode_threshold = 1.8  # derived spec decoding tests need to override this
+
+    @classmethod
+    def get_server_args(cls):
+        """Return the arguments for the server launch. Override in subclasses."""
+        return DEFAULT_SERVER_ARGS + ["--attention-backend", "fa3"]
+
+    @classmethod
+    def setUpClass(cls):
+        # Disable DeepGEMM JIT precompilation to make the server launch faster.
+        # Don't do this in production if you care about inference speed.
+        os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
+        os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
+        model = cls.model
+        cls.process = popen_launch_server(
+            model,
+            cls.base_url,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            other_args=cls.get_server_args(),
+        )
+
+    @classmethod
+    def tearDownClass(cls):
+        kill_process_tree(cls.process.pid)
+
+    def test_gsm8k(self):
+        requests.get(self.base_url + "/flush_cache")
+
+        args = SimpleNamespace(
+            num_shots=4,
+            num_questions=100,
+            max_new_tokens=512,
+            parallel=128,
+            host="http://127.0.0.1",
+            port=int(self.base_url.split(":")[-1]),
+            data_path=GSM_DATASET_PATH,
+        )
+        metrics = run_eval_few_shot_gsm8k(args)
+        print(f"{metrics=}")
+
+        metric_key = "accuracy"
+        self.assertGreater(metrics[metric_key], self.accuracy_threshold)
+
+        server_info = requests.get(self.base_url + "/get_server_info")
+        avg_spec_accept_length = server_info.json()["internal_states"][0][
+            "avg_spec_accept_length"
+        ]
+        print(f"{avg_spec_accept_length=}")
+        self.assertGreater(avg_spec_accept_length, self.spec_decode_threshold)
+
+
+class TestLookaheadSpeculativeDecodingTriton(TestLookaheadSpeculativeDecodingBase):
+
+    @classmethod
+    def get_server_args(cls):
+        return DEFAULT_SERVER_ARGS + ["--attention-backend", "triton"]
+
+
+class TestLookaheadSpeculativeDecodingFlashinfer(TestLookaheadSpeculativeDecodingBase):
+    @classmethod
+    def get_server_args(cls):
+        return DEFAULT_SERVER_ARGS + ["--attention-backend", "flashinfer"]
+
+
+if __name__ == "__main__":
+    unittest.main()