From 24f7cb1ece69d079dd2b1486b480f5169d8ff824 Mon Sep 17 00:00:00 2001 From: Zhihao Zhang <44469583+a4zhangfei@users.noreply.github.com> Date: Mon, 29 Sep 2025 12:06:59 +0800 Subject: [PATCH] [speculative decoding] rename lookahead to ngram (#11010) Co-authored-by: a4zhangfei --- python/pyproject.toml | 4 +- .../layers/attention/flashinfer_backend.py | 52 ++++--------- python/sglang/srt/managers/schedule_batch.py | 16 ++-- python/sglang/srt/managers/scheduler.py | 10 +-- .../srt/model_executor/cuda_graph_runner.py | 14 ++-- python/sglang/srt/server_args.py | 76 +++++++++---------- .../.clang-format | 0 .../lookahead.cpp => cpp_ngram/ngram.cpp} | 37 +++++---- .../lookahead.h => cpp_ngram/ngram.h} | 16 ++-- .../ngram_cache.py} | 18 ++--- .../ngram_cache_binding.cpp} | 24 +++--- .../{cpp_lookahead => cpp_ngram}/param.h | 4 +- .../{cpp_lookahead => cpp_ngram}/queue.h | 0 .../{lookahead_utils.py => ngram_utils.py} | 4 +- .../{lookahead_worker.py => ngram_worker.py} | 40 +++++----- python/sglang/srt/speculative/spec_info.py | 8 +- python/sglang/test/test_utils.py | 2 +- sgl-kernel/CMakeLists.txt | 2 +- .../{lookahead_utils.cu => ngram_utils.cu} | 0 ...lookahead_utils.py => test_ngram_utils.py} | 0 test/srt/run_suite.py | 2 +- ....py => test_ngram_speculative_decoding.py} | 6 +- 22 files changed, 154 insertions(+), 181 deletions(-) rename python/sglang/srt/speculative/{cpp_lookahead => cpp_ngram}/.clang-format (100%) rename python/sglang/srt/speculative/{cpp_lookahead/lookahead.cpp => cpp_ngram/ngram.cpp} (91%) rename python/sglang/srt/speculative/{cpp_lookahead/lookahead.h => cpp_ngram/ngram.h} (91%) rename python/sglang/srt/speculative/{cpp_lookahead/lookahead_cache.py => cpp_ngram/ngram_cache.py} (91%) rename python/sglang/srt/speculative/{cpp_lookahead/lookahead_cache_binding.cpp => cpp_ngram/ngram_cache_binding.cpp} (71%) rename python/sglang/srt/speculative/{cpp_lookahead => cpp_ngram}/param.h (98%) rename python/sglang/srt/speculative/{cpp_lookahead => cpp_ngram}/queue.h (100%) rename python/sglang/srt/speculative/{lookahead_utils.py => ngram_utils.py} (99%) rename python/sglang/srt/speculative/{lookahead_worker.py => ngram_worker.py} (86%) rename sgl-kernel/csrc/speculative/{lookahead_utils.cu => ngram_utils.cu} (100%) rename sgl-kernel/tests/speculative/{test_lookahead_utils.py => test_ngram_utils.py} (100%) rename test/srt/{test_lookahead_speculative_decoding.py => test_ngram_speculative_decoding.py} (95%) diff --git a/python/pyproject.toml b/python/pyproject.toml index 11f076d24..f69934d26 100755 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -103,8 +103,8 @@ dev = ["sglang[test]", "sglang[decord]"] "srt/layers/moe/fused_moe_triton/configs/*/*.json", "srt/layers/quantization/configs/*.json", "srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp", - "srt/speculative/cpp_lookahead/*.cpp", - "srt/speculative/cpp_lookahead/*.h", + "srt/speculative/cpp_ngram/*.cpp", + "srt/speculative/cpp_ngram/*.h", ] [tool.setuptools.packages.find] diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index aaa8b520b..2b69d734c 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -29,7 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.speculative.eagle_utils 
import EagleDraftInput, EagleVerifyInput -from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput +from sglang.srt.speculative.ngram_utils import NgramVerifyInput from sglang.srt.utils import ( get_int_env_var, is_flashinfer_available, @@ -344,9 +344,7 @@ class FlashInferAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[ - Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] - ], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], ): if forward_mode.is_decode_or_idle(): decode_wrappers = [] @@ -453,9 +451,7 @@ class FlashInferAttnBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[ - Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] - ], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], seq_lens_cpu: Optional[torch.Tensor], ): if forward_mode.is_decode_or_idle(): @@ -673,9 +669,7 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], - spec_info: Optional[ - Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] - ], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], fixed_split_size: Optional[int] = None, disable_split_kv: Optional[bool] = None, ): @@ -690,9 +684,7 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], - spec_info: Optional[ - Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] - ], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], fixed_split_size: Optional[int] = None, disable_split_kv: Optional[bool] = None, ): @@ -718,9 +710,7 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], - spec_info: Optional[ - Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] - ], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], fixed_split_size: Optional[int] = None, disable_split_kv: Optional[bool] = None, ): @@ -770,9 +760,7 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], - spec_info: Optional[ - Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] - ], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], fixed_split_size: Optional[int] = None, disable_split_kv: Optional[bool] = None, ): @@ -806,9 +794,7 @@ class FlashInferIndicesUpdaterDecode: paged_kernel_lens_sum: int, kv_indptr: torch.Tensor, kv_start_idx: torch.Tensor, - spec_info: Optional[ - Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] - ], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], seq_lens_cpu: Optional[torch.Tensor], use_sliding_window_kv_pool: bool = False, fixed_split_size: Optional[int] = None, @@ -919,9 +905,7 @@ class FlashInferIndicesUpdaterPrefill: prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[ - Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] - ], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, 
NgramVerifyInput]], fixed_split_size: Optional[int] = None, ): # Keep the signature for type checking. It will be assigned during runtime. @@ -937,9 +921,7 @@ class FlashInferIndicesUpdaterPrefill: prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[ - Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] - ], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], fixed_split_size: Optional[int] = None, ): if use_ragged: @@ -977,9 +959,7 @@ class FlashInferIndicesUpdaterPrefill: prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[ - Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] - ], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], fixed_split_size: Optional[int] = None, ): for wrapper_id in range(2): @@ -1026,9 +1006,7 @@ class FlashInferIndicesUpdaterPrefill: prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[ - Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] - ], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], fixed_split_size: Optional[int] = None, ): for wrapper_id in range(2): @@ -1071,9 +1049,7 @@ class FlashInferIndicesUpdaterPrefill: kv_indptr: torch.Tensor, qo_indptr: torch.Tensor, use_ragged: bool, - spec_info: Optional[ - Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] - ], + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], use_sliding_window_kv_pool: bool = False, fixed_split_size: Optional[int] = None, ): @@ -1102,7 +1078,7 @@ class FlashInferIndicesUpdaterPrefill: custom_mask = None else: assert isinstance( - spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput) + spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput) ) kv_indices, kv_indptr, qo_indptr, custom_mask = ( spec_info.generate_attn_arg_prefill( diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 6457307c1..2efd0de92 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -74,7 +74,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton if TYPE_CHECKING: from sglang.srt.configs.model_config import ModelConfig from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput - from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput + from sglang.srt.speculative.ngram_utils import NgramVerifyInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 @@ -953,9 +953,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Speculative decoding spec_algorithm: SpeculativeAlgorithm = None - spec_info: Optional[ - Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput] - ] = None + spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = ( + None + ) # Whether to return hidden states return_hidden_states: bool = False @@ -1608,7 +1608,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): if ( self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone() - or self.spec_algorithm.is_lookahead() + or self.spec_algorithm.is_ngram() ): # if spec decoding is used, the decode batch is prepared 
inside # `forward_batch_speculative_generation` after running draft models. @@ -1984,9 +1984,9 @@ class ModelWorkerBatch: # Speculative decoding spec_algorithm: SpeculativeAlgorithm = None - spec_info: Optional[ - Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput] - ] = None + spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = ( + None + ) # If set, the output of the batch contains the hidden states of the run. capture_hidden_mode: CaptureHiddenMode = None hicache_consumer_index: int = -1 diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 94cd8e16f..893a0b0a1 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -388,10 +388,10 @@ class Scheduler( target_worker=self.tp_worker, dp_rank=dp_rank, ) - elif self.spec_algorithm.is_lookahead(): - from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker + elif self.spec_algorithm.is_ngram(): + from sglang.srt.speculative.ngram_worker import NGRAMWorker - self.draft_worker = LOOKAHEADWorker( + self.draft_worker = NGRAMWorker( gpu_id=gpu_id, tp_rank=tp_rank, moe_ep_rank=moe_ep_rank, @@ -826,7 +826,7 @@ class Scheduler( token_to_kv_pool_allocator=self.token_to_kv_pool_allocator, draft_token_to_kv_pool=( None - if self.draft_worker is None or self.spec_algorithm.is_lookahead() + if self.draft_worker is None or self.spec_algorithm.is_ngram() else self.draft_worker.model_runner.token_to_kv_pool ), req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator, @@ -863,7 +863,7 @@ class Scheduler( token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(), draft_token_to_kv_pool=( None - if self.draft_worker is None or self.spec_algorithm.is_lookahead() + if self.draft_worker is None or self.spec_algorithm.is_ngram() else self.draft_worker.model_runner.token_to_kv_pool ), req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator, diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index dc51102ec..4f09e621a 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -246,7 +246,7 @@ class CudaGraphRunner: if ( model_runner.spec_algorithm.is_eagle() or model_runner.spec_algorithm.is_standalone() - or model_runner.spec_algorithm.is_lookahead() + or model_runner.spec_algorithm.is_ngram() ): if self.model_runner.is_draft_worker: raise RuntimeError("This should not happen") @@ -413,12 +413,12 @@ class CudaGraphRunner: forward_batch.can_run_tbo if self.enable_two_batch_overlap else True ) - is_lookahead_supported = ( + is_ngram_supported = ( ( forward_batch.batch_size * self.num_tokens_per_bs == forward_batch.input_ids.numel() ) - if self.model_runner.spec_algorithm.is_lookahead() + if self.model_runner.spec_algorithm.is_ngram() else True ) @@ -427,7 +427,7 @@ class CudaGraphRunner: and is_encoder_lens_supported and is_tbo_supported and capture_hidden_mode_matches - and is_lookahead_supported + and is_ngram_supported ) def capture(self) -> None: @@ -838,10 +838,10 @@ class CudaGraphRunner: seq_lens_cpu=None, ) - elif self.model_runner.spec_algorithm.is_lookahead(): - from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput + elif self.model_runner.spec_algorithm.is_ngram(): + from sglang.srt.speculative.ngram_utils import NgramVerifyInput - spec_info = LookaheadVerifyInput( + spec_info = NgramVerifyInput( 
draft_token=None, tree_mask=self.custom_mask, positions=None, diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index f076b8a1f..bbe96fc9b 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -286,14 +286,14 @@ class ServerArgs: speculative_accept_threshold_acc: float = 1.0 speculative_token_map: Optional[str] = None speculative_attention_mode: str = "prefill" - # For lookahead only - speculative_lookahead_min_match_window_size: int = 1 - speculative_lookahead_max_match_window_size: int = 12 - speculative_lookahead_min_bfs_breadth: int = 1 - speculative_lookahead_max_bfs_breadth: int = 10 - speculative_lookahead_match_type: Literal["BFS", "PROB"] = "BFS" - speculative_lookahead_branch_length: int = 18 - speculative_lookahead_capacity: int = 10 * 1000 * 1000 + # For ngram only + speculative_ngram_min_match_window_size: int = 1 + speculative_ngram_max_match_window_size: int = 12 + speculative_ngram_min_bfs_breadth: int = 1 + speculative_ngram_max_bfs_breadth: int = 10 + speculative_ngram_match_type: Literal["BFS", "PROB"] = "BFS" + speculative_ngram_branch_length: int = 18 + speculative_ngram_capacity: int = 10 * 1000 * 1000 # Expert parallelism ep_size: int = 1 @@ -566,7 +566,7 @@ class ServerArgs: # Standalone speculative decoding needs more memory than other speculative # decoding algorithms since the draft model is typically larger. reserved_mem += 6 * 1024 - elif self.speculative_algorithm != "LOOKAHEAD": + elif self.speculative_algorithm != "NGRAM": reserved_mem += 2 * 1024 if self.enable_dp_attention: reserved_mem += 4 * 1024 @@ -1024,23 +1024,23 @@ class ServerArgs: "speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend." ) - if self.speculative_algorithm == "LOOKAHEAD": + if self.speculative_algorithm == "NGRAM": if not self.device.startswith("cuda"): raise ValueError( - "Lookahead speculative decoding only supports CUDA device." + "Ngram speculative decoding only supports CUDA device." ) if self.max_running_requests is None: self.max_running_requests = 48 self.disable_overlap_schedule = True self.enable_mixed_chunk = False - self.speculative_eagle_topk = self.speculative_lookahead_max_bfs_breadth + self.speculative_eagle_topk = self.speculative_ngram_max_bfs_breadth if self.speculative_num_draft_tokens is None: self.speculative_num_draft_tokens = ( - self.speculative_lookahead_max_match_window_size + self.speculative_ngram_max_match_window_size ) logger.warning( "The overlap scheduler and mixed chunked prefill are disabled because of " - "using lookahead speculative decoding." + "using ngram speculative decoding." ) if ( @@ -1052,9 +1052,9 @@ class ServerArgs: "speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend." ) if self.enable_dp_attention: - # TODO: support dp attention for lookahead speculative decoding + # TODO: support dp attention for ngram speculative decoding raise ValueError( - "Currently lookahead speculative decoding does not support dp attention." + "Currently ngram speculative decoding does not support dp attention." 
) def _handle_load_format(self): @@ -1921,7 +1921,7 @@ class ServerArgs: parser.add_argument( "--speculative-algorithm", type=str, - choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "LOOKAHEAD"], + choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"], help="Speculative algorithm.", ) parser.add_argument( @@ -1981,49 +1981,49 @@ class ServerArgs: help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.", default=ServerArgs.speculative_attention_mode, ) - # Lookahead speculative decoding + # Ngram speculative decoding parser.add_argument( - "--speculative-lookahead-min-match-window-size", + "--speculative-ngram-min-match-window-size", type=int, - default=ServerArgs.speculative_lookahead_min_match_window_size, - help="The minimum window size for pattern matching in lookahead speculative decoding.", + default=ServerArgs.speculative_ngram_min_match_window_size, + help="The minimum window size for pattern matching in ngram speculative decoding.", ) parser.add_argument( - "--speculative-lookahead-max-match-window-size", + "--speculative-ngram-max-match-window-size", type=int, - default=ServerArgs.speculative_lookahead_max_match_window_size, - help="The maximum window size for pattern matching in lookahead speculative decoding.", + default=ServerArgs.speculative_ngram_max_match_window_size, + help="The maximum window size for pattern matching in ngram speculative decoding.", ) parser.add_argument( - "--speculative-lookahead-min-bfs-breadth", + "--speculative-ngram-min-bfs-breadth", type=int, - default=ServerArgs.speculative_lookahead_min_bfs_breadth, - help="The minimum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.", + default=ServerArgs.speculative_ngram_min_bfs_breadth, + help="The minimum breadth for BFS (Breadth-First Search) in ngram speculative decoding.", ) parser.add_argument( - "--speculative-lookahead-max-bfs-breadth", + "--speculative-ngram-max-bfs-breadth", type=int, - default=ServerArgs.speculative_lookahead_max_bfs_breadth, - help="The maximum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.", + default=ServerArgs.speculative_ngram_max_bfs_breadth, + help="The maximum breadth for BFS (Breadth-First Search) in ngram speculative decoding.", ) parser.add_argument( - "--speculative-lookahead-match-type", + "--speculative-ngram-match-type", type=str, choices=["BFS", "PROB"], - default=ServerArgs.speculative_lookahead_match_type, + default=ServerArgs.speculative_ngram_match_type, help="The match type for cache tree.", ) parser.add_argument( - "--speculative-lookahead-branch-length", + "--speculative-ngram-branch-length", type=int, - default=ServerArgs.speculative_lookahead_branch_length, - help="The branch length for lookahead speculative decoding.", + default=ServerArgs.speculative_ngram_branch_length, + help="The branch length for ngram speculative decoding.", ) parser.add_argument( - "--speculative-lookahead-capacity", + "--speculative-ngram-capacity", type=int, - default=ServerArgs.speculative_lookahead_capacity, - help="The cache capacity for lookahead speculative decoding.", + default=ServerArgs.speculative_ngram_capacity, + help="The cache capacity for ngram speculative decoding.", ) # Expert parallelism diff --git a/python/sglang/srt/speculative/cpp_lookahead/.clang-format b/python/sglang/srt/speculative/cpp_ngram/.clang-format similarity index 100% rename from python/sglang/srt/speculative/cpp_lookahead/.clang-format rename to 
python/sglang/srt/speculative/cpp_ngram/.clang-format diff --git a/python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp b/python/sglang/srt/speculative/cpp_ngram/ngram.cpp similarity index 91% rename from python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp rename to python/sglang/srt/speculative/cpp_ngram/ngram.cpp index c47ebcd8d..51172c5dd 100644 --- a/python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp +++ b/python/sglang/srt/speculative/cpp_ngram/ngram.cpp @@ -1,16 +1,16 @@ -#include "lookahead.h" +#include "ngram.h" #include #include -namespace lookahead { +namespace ngram { struct Node { std::unordered_map next; }; -Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector& tree, int root) { - Lookahead::Result info; +Ngram::Result fillResult(int last_token, int draft_token_num, std::vector& tree, int root) { + Ngram::Result info; std::vector prevs; info.token.reserve(draft_token_num); prevs.reserve(draft_token_num); @@ -50,7 +50,7 @@ Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector> -Lookahead::match(const std::vector& tokens, size_t batch_size) const { +std::vector> Ngram::match(const std::vector& tokens, size_t batch_size) const { auto draft_token_num = param_.get_draft_token_num(batch_size); auto min_match_window_size = param_.get_min_match_window_size(batch_size); auto max_match_window_size = param_.max_match_window_size; @@ -154,7 +153,7 @@ Lookahead::match(const std::vector& tokens, size_t batch_size) const { return result; } -void Lookahead::squeeze(size_t count) { +void Ngram::squeeze(size_t count) { if (!(node_pool_.size() >= free_node_count_ + count)) { throw std::runtime_error( "Insufficient node size to release required nodes. " @@ -177,13 +176,13 @@ void Lookahead::squeeze(size_t count) { } } -void Lookahead::synchronize() const { +void Ngram::synchronize() const { while (!insert_queue_.empty()) { std::this_thread::sleep_for(std::chrono::microseconds(10)); } } -void Lookahead::insert() { +void Ngram::insert() { while (!quit_flag_) { std::vector data; if (!insert_queue_.dequeue(data)) { @@ -239,13 +238,13 @@ void Lookahead::insert() { } } -void Lookahead::asyncInsert(std::vector>&& tokens) { +void Ngram::asyncInsert(std::vector>&& tokens) { for (auto&& token : tokens) { insert_queue_.enqueue(std::move(token)); } } -Lookahead::Result Lookahead::matchBFS(const std::vector& tokens, size_t batch_size) const { +Ngram::Result Ngram::matchBFS(const std::vector& tokens, size_t batch_size) const { std::vector> nodes = match(tokens, batch_size); double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) / @@ -284,7 +283,7 @@ Lookahead::Result Lookahead::matchBFS(const std::vector& tokens, size_t return fillResult(tokens.back(), draft_token_num + 1, tree, root); } -Lookahead::Result Lookahead::matchProb(const std::vector& tokens, size_t batch_size) const { +Ngram::Result Ngram::matchProb(const std::vector& tokens, size_t batch_size) const { std::vector> nodes = match(tokens, batch_size); auto draft_token_num = param_.get_draft_token_num(batch_size); @@ -346,10 +345,10 @@ Lookahead::Result Lookahead::matchProb(const std::vector& tokens, size_ return fillResult(tokens.back(), draft_token_num + 1, tree, root); } -Lookahead::Result Lookahead::batchMatch(const std::vector>& tokens) const { +Ngram::Result Ngram::batchMatch(const std::vector>& tokens) const { std::unique_lock lock(mutex_); Result merged_result; - auto match_func = param_.match_type == "BFS" ? 
&Lookahead::matchBFS : &Lookahead::matchProb; + auto match_func = param_.match_type == "BFS" ? &Ngram::matchBFS : &Ngram::matchProb; for (const auto& tks : tokens) { Result res = (this->*match_func)(tks, tokens.size()); merged_result.token.insert(merged_result.token.end(), res.token.begin(), res.token.end()); @@ -358,7 +357,7 @@ Lookahead::Result Lookahead::batchMatch(const std::vector>& return merged_result; } -void Lookahead::Result::truncate(size_t n) { +void Ngram::Result::truncate(size_t n) { if (n < token.size()) { int full_n = token.size(); for (int i = 1; i < n; ++i) { @@ -369,4 +368,4 @@ void Lookahead::Result::truncate(size_t n) { } } -} // namespace lookahead +} // namespace ngram diff --git a/python/sglang/srt/speculative/cpp_lookahead/lookahead.h b/python/sglang/srt/speculative/cpp_ngram/ngram.h similarity index 91% rename from python/sglang/srt/speculative/cpp_lookahead/lookahead.h rename to python/sglang/srt/speculative/cpp_ngram/ngram.h index 9c6c82c92..bf0af0df9 100644 --- a/python/sglang/srt/speculative/cpp_lookahead/lookahead.h +++ b/python/sglang/srt/speculative/cpp_ngram/ngram.h @@ -15,7 +15,7 @@ #include "param.h" #include "queue.h" -namespace lookahead { +namespace ngram { struct TrieNode { std::unordered_map child; @@ -34,7 +34,7 @@ struct TrieNode { std::multiset sorted_children; }; -class Lookahead { +class Ngram { std::vector nodes_; std::vector node_pool_; size_t free_node_count_; @@ -61,12 +61,12 @@ class Lookahead { std::vector> match_tmp_data_; public: - Lookahead(size_t capacity, const Param& param); - Lookahead() = default; - ~Lookahead(); + Ngram(size_t capacity, const Param& param); + Ngram() = default; + ~Ngram(); - static Lookahead& instance() { - static Lookahead instance; + static Ngram& instance() { + static Ngram instance; return instance; } @@ -107,4 +107,4 @@ class Lookahead { void insert(); }; -} // namespace lookahead +} // namespace ngram diff --git a/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache.py b/python/sglang/srt/speculative/cpp_ngram/ngram_cache.py similarity index 91% rename from python/sglang/srt/speculative/cpp_lookahead/lookahead_cache.py rename to python/sglang/srt/speculative/cpp_ngram/ngram_cache.py index 871b60878..8b1eb8eea 100644 --- a/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache.py +++ b/python/sglang/srt/speculative/cpp_ngram/ngram_cache.py @@ -1,7 +1,5 @@ # -*- coding: utf-8 -*- -# from sglang.op.lookahead import Lookahead, Param - import logging import os from typing import List, Tuple @@ -12,17 +10,17 @@ from torch.utils.cpp_extension import load logger = logging.getLogger(__name__) _abs_path = os.path.dirname(os.path.abspath(__file__)) -lookahead_cache_cpp = load( - name="lookahead_cache_cpp", +ngram_cache_cpp = load( + name="ngram_cache_cpp", sources=[ - f"{_abs_path}/lookahead_cache_binding.cpp", - f"{_abs_path}/lookahead.cpp", + f"{_abs_path}/ngram_cache_binding.cpp", + f"{_abs_path}/ngram.cpp", ], extra_cflags=["-O3", "-std=c++20"], ) -class LookaheadCache: +class NgramCache: def __init__( self, branch_length=18, @@ -34,7 +32,7 @@ class LookaheadCache: match_type="BFS", capacity=1000000, ): - param = lookahead_cache_cpp.Param() + param = ngram_cache_cpp.Param() param.branch_length = branch_length param.min_match_window_size = min_match_window_size param.max_match_window_size = max_match_window_size @@ -42,7 +40,7 @@ class LookaheadCache: param.max_bfs_breadth = max_bfs_breadth param.draft_token_num = draft_token_num param.match_type = match_type - self.cache = 
lookahead_cache_cpp.Lookahead(capacity, param) + self.cache = ngram_cache_cpp.Ngram(capacity, param) self.default_mask = np.ones((1, 1), dtype=np.int64) self.draft_token_num = draft_token_num @@ -131,7 +129,7 @@ if __name__ == "__main__": [1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 44, 55, 66, 77, 88, 99, 100], ] - cache = LookaheadCache(branch_length=12, draft_token_num=8) + cache = NgramCache(branch_length=12, draft_token_num=8) cache.batch_put(token_ids) cache.synchronize() diff --git a/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache_binding.cpp b/python/sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp similarity index 71% rename from python/sglang/srt/speculative/cpp_lookahead/lookahead_cache_binding.cpp rename to python/sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp index 8c48a66ae..ac5b931f9 100644 --- a/python/sglang/srt/speculative/cpp_lookahead/lookahead_cache_binding.cpp +++ b/python/sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp @@ -1,19 +1,19 @@ #include #include -#include "lookahead.h" +#include "ngram.h" -PYBIND11_MODULE(lookahead_cache_cpp, m) { - using namespace lookahead; +PYBIND11_MODULE(ngram_cache_cpp, m) { + using namespace ngram; namespace py = pybind11; m.doc() = ""; - py::class_(m, "Lookahead") + py::class_(m, "Ngram") .def(py::init(), py::arg("capacity"), py::arg("param")) - .def("asyncInsert", &Lookahead::asyncInsert, "") - .def("batchMatch", &Lookahead::batchMatch, "") - .def("reset", &Lookahead::reset, "") - .def("synchronize", &Lookahead::synchronize, ""); + .def("asyncInsert", &Ngram::asyncInsert, "") + .def("batchMatch", &Ngram::batchMatch, "") + .def("reset", &Ngram::reset, "") + .def("synchronize", &Ngram::synchronize, ""); py::class_(m, "Param") .def(py::init<>()) @@ -35,9 +35,9 @@ PYBIND11_MODULE(lookahead_cache_cpp, m) { .def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "") .def("detail", &Param::detail, ""); - py::class_(m, "Result") + py::class_(m, "Result") .def(py::init<>()) - .def_readwrite("token", &Lookahead::Result::token) - .def_readwrite("mask", &Lookahead::Result::mask) - .def("truncate", &Lookahead::Result::truncate); + .def_readwrite("token", &Ngram::Result::token) + .def_readwrite("mask", &Ngram::Result::mask) + .def("truncate", &Ngram::Result::truncate); } diff --git a/python/sglang/srt/speculative/cpp_lookahead/param.h b/python/sglang/srt/speculative/cpp_ngram/param.h similarity index 98% rename from python/sglang/srt/speculative/cpp_lookahead/param.h rename to python/sglang/srt/speculative/cpp_ngram/param.h index 2d8b1f875..967832ad6 100644 --- a/python/sglang/srt/speculative/cpp_lookahead/param.h +++ b/python/sglang/srt/speculative/cpp_ngram/param.h @@ -9,7 +9,7 @@ #include #include -namespace lookahead { +namespace ngram { struct Param { bool enable; @@ -122,4 +122,4 @@ struct Param { } }; -} // namespace lookahead +} // namespace ngram diff --git a/python/sglang/srt/speculative/cpp_lookahead/queue.h b/python/sglang/srt/speculative/cpp_ngram/queue.h similarity index 100% rename from python/sglang/srt/speculative/cpp_lookahead/queue.h rename to python/sglang/srt/speculative/cpp_ngram/queue.h diff --git a/python/sglang/srt/speculative/lookahead_utils.py b/python/sglang/srt/speculative/ngram_utils.py similarity index 99% rename from python/sglang/srt/speculative/lookahead_utils.py rename to python/sglang/srt/speculative/ngram_utils.py index 5ca6cb025..d675d35b5 100644 --- a/python/sglang/srt/speculative/lookahead_utils.py +++ b/python/sglang/srt/speculative/ngram_utils.py @@ -42,7 
+42,7 @@ elif is_hip(): @dataclass -class LookaheadVerifyInput: +class NgramVerifyInput: def __init__( self, draft_token: torch.Tensor, @@ -408,5 +408,5 @@ class LookaheadVerifyInput: def filter_batch(self, new_indices: torch.Tensor): pass - def merge_batch(self, spec_info: LookaheadVerifyInput): + def merge_batch(self, spec_info: NgramVerifyInput): pass diff --git a/python/sglang/srt/speculative/lookahead_worker.py b/python/sglang/srt/speculative/ngram_worker.py similarity index 86% rename from python/sglang/srt/speculative/lookahead_worker.py rename to python/sglang/srt/speculative/ngram_worker.py index 040078ac7..cb0155911 100644 --- a/python/sglang/srt/speculative/lookahead_worker.py +++ b/python/sglang/srt/speculative/ngram_worker.py @@ -12,8 +12,8 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.model_executor.forward_batch_info import ForwardMode from sglang.srt.server_args import ServerArgs -from sglang.srt.speculative.cpp_lookahead.lookahead_cache import LookaheadCache -from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput +from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache +from sglang.srt.speculative.ngram_utils import NgramVerifyInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import broadcast_pyobj @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) USE_FULL_MASK = True -class LOOKAHEADWorker: +class NGRAMWorker: def __init__( self, server_args: ServerArgs, @@ -38,9 +38,9 @@ class LOOKAHEADWorker: self.tp_rank = tp_rank self.page_size = server_args.page_size self.draft_token_num: int = server_args.speculative_num_draft_tokens - self.branch_length: int = server_args.speculative_lookahead_branch_length + self.branch_length: int = server_args.speculative_ngram_branch_length self.max_match_window_size: int = ( - server_args.speculative_lookahead_max_match_window_size + server_args.speculative_ngram_max_match_window_size ) self.max_batch_size = target_worker.max_running_requests @@ -48,18 +48,18 @@ class LOOKAHEADWorker: self._init_preallocated_tensors() - self.lookahead_cache = LookaheadCache( - min_match_window_size=server_args.speculative_lookahead_min_match_window_size, - max_match_window_size=server_args.speculative_lookahead_max_match_window_size, - min_bfs_breadth=server_args.speculative_lookahead_min_bfs_breadth, - max_bfs_breadth=server_args.speculative_lookahead_max_bfs_breadth, - capacity=server_args.speculative_lookahead_capacity, - branch_length=server_args.speculative_lookahead_branch_length, + self.ngram_cache = NgramCache( + min_match_window_size=server_args.speculative_ngram_min_match_window_size, + max_match_window_size=server_args.speculative_ngram_max_match_window_size, + min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth, + max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth, + capacity=server_args.speculative_ngram_capacity, + branch_length=server_args.speculative_ngram_branch_length, draft_token_num=server_args.speculative_num_draft_tokens, ) def clear_cache_pool(self): - self.lookahead_cache.reset() + self.ngram_cache.reset() def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int): seq2_len = len(seq2) @@ -124,14 +124,14 @@ class LOOKAHEADWorker: ) -> tuple[np.ndarray, np.ndarray]: bs = batch.batch_size() - self.lookahead_cache.synchronize() + self.ngram_cache.synchronize() batch_tokens = [] for req in batch.reqs: check_token = 
self._efficient_concat_last_n( req.origin_input_ids, req.output_ids, self.max_match_window_size ) batch_tokens.append(check_token) - req_drafts, mask = self.lookahead_cache.batch_get(batch_tokens) + req_drafts, mask = self.ngram_cache.batch_get(batch_tokens) total_draft_token_num = len(req_drafts) # Check if speculative decoding is needed; here we always enforce it @@ -184,9 +184,9 @@ class LOOKAHEADWorker: tree_mask.append(req_mask.flatten()) tree_mask = torch.cat(tree_mask, dim=0) - batch.spec_algorithm = SpeculativeAlgorithm.LOOKAHEAD + batch.spec_algorithm = SpeculativeAlgorithm.NGRAM batch.forward_mode = ForwardMode.TARGET_VERIFY - batch.spec_info = LookaheadVerifyInput( + batch.spec_info = NgramVerifyInput( draft_tokens, tree_mask, positions, @@ -197,7 +197,7 @@ class LOOKAHEADWorker: ) batch.spec_info.prepare_for_verify(batch, self.page_size) - def _update_lookahead_cache(self, batch: ScheduleBatch): + def _update_ngram_cache(self, batch: ScheduleBatch): batch_tokens = [] for req in batch.reqs: # FIXME: Whether to insert 'extend' into the cache or not, after testing, @@ -209,7 +209,7 @@ class LOOKAHEADWorker: req.origin_input_ids, req.output_ids, self.branch_length ) batch_tokens.append(put_ids) - self.lookahead_cache.batch_put(batch_tokens) + self.ngram_cache.batch_put(batch_tokens) def forward_batch_speculative_generation(self, batch: ScheduleBatch): self._prepare_for_speculative_decoding(batch) @@ -227,7 +227,7 @@ class LOOKAHEADWorker: logits_output, next_token_ids, num_accepted_tokens = verify_input.verify( batch, logits_output, self.page_size ) - self._update_lookahead_cache(batch) + self._update_ngram_cache(batch) batch.forward_mode = ForwardMode.DECODE else: diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index a865d0ff6..64a02f19e 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -6,7 +6,7 @@ class SpeculativeAlgorithm(IntEnum): EAGLE = auto() EAGLE3 = auto() STANDALONE = auto() - LOOKAHEAD = auto() + NGRAM = auto() def is_none(self): return self == SpeculativeAlgorithm.NONE @@ -20,8 +20,8 @@ class SpeculativeAlgorithm(IntEnum): def is_standalone(self): return self == SpeculativeAlgorithm.STANDALONE - def is_lookahead(self): - return self == SpeculativeAlgorithm.LOOKAHEAD + def is_ngram(self): + return self == SpeculativeAlgorithm.NGRAM @staticmethod def from_string(name: str): @@ -29,7 +29,7 @@ class SpeculativeAlgorithm(IntEnum): "EAGLE": SpeculativeAlgorithm.EAGLE, "EAGLE3": SpeculativeAlgorithm.EAGLE3, "STANDALONE": SpeculativeAlgorithm.STANDALONE, - "LOOKAHEAD": SpeculativeAlgorithm.LOOKAHEAD, + "NGRAM": SpeculativeAlgorithm.NGRAM, None: SpeculativeAlgorithm.NONE, } if name is not None: diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index 1c5cd2fd1..2e9a16896 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -82,7 +82,7 @@ DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = ( "meta-llama/Llama-3.1-8B-Instruct" ) DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct" -DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct" +DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct" # Other use cases DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = ( diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index ea39e239a..87c271e20 100644 --- a/sgl-kernel/CMakeLists.txt +++ 
b/sgl-kernel/CMakeLists.txt @@ -314,7 +314,7 @@ set(SOURCES "csrc/kvcacheio/transfer.cu" "csrc/speculative/eagle_utils.cu" - "csrc/speculative/lookahead_utils.cu" + "csrc/speculative/ngram_utils.cu" "csrc/speculative/packbit.cu" "csrc/speculative/speculative_sampling.cu" diff --git a/sgl-kernel/csrc/speculative/lookahead_utils.cu b/sgl-kernel/csrc/speculative/ngram_utils.cu similarity index 100% rename from sgl-kernel/csrc/speculative/lookahead_utils.cu rename to sgl-kernel/csrc/speculative/ngram_utils.cu diff --git a/sgl-kernel/tests/speculative/test_lookahead_utils.py b/sgl-kernel/tests/speculative/test_ngram_utils.py similarity index 100% rename from sgl-kernel/tests/speculative/test_lookahead_utils.py rename to sgl-kernel/tests/speculative/test_ngram_utils.py diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 9c15c5ba8..5dbb7cfb7 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -79,7 +79,7 @@ suites = { TestFile("test_hidden_states.py", 55), TestFile("test_hybrid_attn_backend.py", 100), TestFile("test_standalone_speculative_decoding.py", 250), - TestFile("test_lookahead_speculative_decoding.py", 250), + TestFile("test_ngram_speculative_decoding.py", 250), TestFile("test_input_embeddings.py", 38), TestFile("test_io_struct.py", 8), TestFile("test_jinja_template_utils.py", 1), diff --git a/test/srt/test_lookahead_speculative_decoding.py b/test/srt/test_ngram_speculative_decoding.py similarity index 95% rename from test/srt/test_lookahead_speculative_decoding.py rename to test/srt/test_ngram_speculative_decoding.py index 1cf3e2101..c791915a8 100644 --- a/test/srt/test_lookahead_speculative_decoding.py +++ b/test/srt/test_ngram_speculative_decoding.py @@ -7,7 +7,7 @@ import requests from sglang.srt.utils import kill_process_tree from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k from sglang.test.test_utils import ( - DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST, + DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, DEFAULT_URL_FOR_TEST, CustomTestCase, @@ -23,7 +23,7 @@ DEFAULT_SERVER_ARGS = [ "--cuda-graph-max-bs", "8", "--speculative-algorithm", - "LOOKAHEAD", + "NGRAM", "--speculative-num-draft-tokens", "16", "--mem-fraction-static", @@ -33,7 +33,7 @@ DEFAULT_SERVER_ARGS = [ class TestStandaloneSpeculativeDecodingBase(CustomTestCase): - model = DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST + model = DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST base_url = DEFAULT_URL_FOR_TEST accuracy_threshold = 0.79 # derived tests need to override this spec_decode_threshold = 1.8 # derived spec decoding tests need to override this
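
Note (sketch only, not part of this patch): for anyone exercising the renamed flags locally, the snippet below mirrors the server arguments that appear in this diff and in test_ngram_speculative_decoding.py. The `python -m sglang.launch_server` entry point and the model path are assumptions taken from the test defaults; the ngram flag values shown are the defaults declared in server_args.py.

    # Hedged usage sketch: launch a target model with ngram speculative decoding
    # using the renamed CLI flags (formerly --speculative-lookahead-*).
    import subprocess

    cmd = [
        "python", "-m", "sglang.launch_server",
        "--model-path", "Qwen/Qwen2.5-Coder-7B-Instruct",   # model used by the test suite
        "--speculative-algorithm", "NGRAM",                  # formerly "LOOKAHEAD"
        "--speculative-num-draft-tokens", "16",
        "--speculative-ngram-max-match-window-size", "12",   # default from server_args.py
        "--speculative-ngram-max-bfs-breadth", "10",         # default from server_args.py
        "--speculative-ngram-branch-length", "18",           # default from server_args.py
        "--cuda-graph-max-bs", "8",
    ]
    subprocess.run(cmd, check=True)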