[speculative decoding] rename lookahead to ngram (#11010)
Co-authored-by: a4zhangfei <a4zhangfei@qq.com>
This commit is contained in:
@@ -103,8 +103,8 @@ dev = ["sglang[test]", "sglang[decord]"]
|
|||||||
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
|
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
|
||||||
"srt/layers/quantization/configs/*.json",
|
"srt/layers/quantization/configs/*.json",
|
||||||
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
|
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
|
||||||
"srt/speculative/cpp_lookahead/*.cpp",
|
"srt/speculative/cpp_ngram/*.cpp",
|
||||||
"srt/speculative/cpp_lookahead/*.h",
|
"srt/speculative/cpp_ngram/*.h",
|
||||||
]
|
]
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType
|
|||||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import (
|
||||||
get_int_env_var,
|
get_int_env_var,
|
||||||
is_flashinfer_available,
|
is_flashinfer_available,
|
||||||
@@ -344,9 +344,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
|
||||||
],
|
|
||||||
):
|
):
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
decode_wrappers = []
|
decode_wrappers = []
|
||||||
@@ -453,9 +451,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
|
||||||
],
|
|
||||||
seq_lens_cpu: Optional[torch.Tensor],
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
):
|
):
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
@@ -673,9 +669,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
|
||||||
],
|
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
disable_split_kv: Optional[bool] = None,
|
disable_split_kv: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
@@ -690,9 +684,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
|
||||||
],
|
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
disable_split_kv: Optional[bool] = None,
|
disable_split_kv: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
@@ -718,9 +710,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
|
||||||
],
|
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
disable_split_kv: Optional[bool] = None,
|
disable_split_kv: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
@@ -770,9 +760,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
|
||||||
],
|
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
disable_split_kv: Optional[bool] = None,
|
disable_split_kv: Optional[bool] = None,
|
||||||
):
|
):
|
||||||
@@ -806,9 +794,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
paged_kernel_lens_sum: int,
|
paged_kernel_lens_sum: int,
|
||||||
kv_indptr: torch.Tensor,
|
kv_indptr: torch.Tensor,
|
||||||
kv_start_idx: torch.Tensor,
|
kv_start_idx: torch.Tensor,
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
|
||||||
],
|
|
||||||
seq_lens_cpu: Optional[torch.Tensor],
|
seq_lens_cpu: Optional[torch.Tensor],
|
||||||
use_sliding_window_kv_pool: bool = False,
|
use_sliding_window_kv_pool: bool = False,
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
@@ -919,9 +905,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
|
||||||
],
|
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
# Keep the signature for type checking. It will be assigned during runtime.
|
# Keep the signature for type checking. It will be assigned during runtime.
|
||||||
@@ -937,9 +921,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
|
||||||
],
|
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
if use_ragged:
|
if use_ragged:
|
||||||
@@ -977,9 +959,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
|
||||||
],
|
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
@@ -1026,9 +1006,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
|
||||||
],
|
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
@@ -1071,9 +1049,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
kv_indptr: torch.Tensor,
|
kv_indptr: torch.Tensor,
|
||||||
qo_indptr: torch.Tensor,
|
qo_indptr: torch.Tensor,
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
|
||||||
],
|
|
||||||
use_sliding_window_kv_pool: bool = False,
|
use_sliding_window_kv_pool: bool = False,
|
||||||
fixed_split_size: Optional[int] = None,
|
fixed_split_size: Optional[int] = None,
|
||||||
):
|
):
|
||||||
@@ -1102,7 +1078,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
custom_mask = None
|
custom_mask = None
|
||||||
else:
|
else:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
|
spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
|
||||||
)
|
)
|
||||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||||
spec_info.generate_attn_arg_prefill(
|
spec_info.generate_attn_arg_prefill(
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.configs.model_config import ModelConfig
|
from sglang.srt.configs.model_config import ModelConfig
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
|
|
||||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||||
@@ -953,9 +953,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
spec_algorithm: SpeculativeAlgorithm = None
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
|
||||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
None
|
||||||
] = None
|
)
|
||||||
|
|
||||||
# Whether to return hidden states
|
# Whether to return hidden states
|
||||||
return_hidden_states: bool = False
|
return_hidden_states: bool = False
|
||||||
@@ -1608,7 +1608,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
|||||||
if (
|
if (
|
||||||
self.spec_algorithm.is_eagle()
|
self.spec_algorithm.is_eagle()
|
||||||
or self.spec_algorithm.is_standalone()
|
or self.spec_algorithm.is_standalone()
|
||||||
or self.spec_algorithm.is_lookahead()
|
or self.spec_algorithm.is_ngram()
|
||||||
):
|
):
|
||||||
# if spec decoding is used, the decode batch is prepared inside
|
# if spec decoding is used, the decode batch is prepared inside
|
||||||
# `forward_batch_speculative_generation` after running draft models.
|
# `forward_batch_speculative_generation` after running draft models.
|
||||||
@@ -1984,9 +1984,9 @@ class ModelWorkerBatch:
|
|||||||
|
|
||||||
# Speculative decoding
|
# Speculative decoding
|
||||||
spec_algorithm: SpeculativeAlgorithm = None
|
spec_algorithm: SpeculativeAlgorithm = None
|
||||||
spec_info: Optional[
|
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
|
||||||
Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput]
|
None
|
||||||
] = None
|
)
|
||||||
# If set, the output of the batch contains the hidden states of the run.
|
# If set, the output of the batch contains the hidden states of the run.
|
||||||
capture_hidden_mode: CaptureHiddenMode = None
|
capture_hidden_mode: CaptureHiddenMode = None
|
||||||
hicache_consumer_index: int = -1
|
hicache_consumer_index: int = -1
|
||||||
|
|||||||
@@ -388,10 +388,10 @@ class Scheduler(
|
|||||||
target_worker=self.tp_worker,
|
target_worker=self.tp_worker,
|
||||||
dp_rank=dp_rank,
|
dp_rank=dp_rank,
|
||||||
)
|
)
|
||||||
elif self.spec_algorithm.is_lookahead():
|
elif self.spec_algorithm.is_ngram():
|
||||||
from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker
|
from sglang.srt.speculative.ngram_worker import NGRAMWorker
|
||||||
|
|
||||||
self.draft_worker = LOOKAHEADWorker(
|
self.draft_worker = NGRAMWorker(
|
||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
tp_rank=tp_rank,
|
tp_rank=tp_rank,
|
||||||
moe_ep_rank=moe_ep_rank,
|
moe_ep_rank=moe_ep_rank,
|
||||||
@@ -826,7 +826,7 @@ class Scheduler(
|
|||||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
draft_token_to_kv_pool=(
|
draft_token_to_kv_pool=(
|
||||||
None
|
None
|
||||||
if self.draft_worker is None or self.spec_algorithm.is_lookahead()
|
if self.draft_worker is None or self.spec_algorithm.is_ngram()
|
||||||
else self.draft_worker.model_runner.token_to_kv_pool
|
else self.draft_worker.model_runner.token_to_kv_pool
|
||||||
),
|
),
|
||||||
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||||
@@ -863,7 +863,7 @@ class Scheduler(
|
|||||||
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
|
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
|
||||||
draft_token_to_kv_pool=(
|
draft_token_to_kv_pool=(
|
||||||
None
|
None
|
||||||
if self.draft_worker is None or self.spec_algorithm.is_lookahead()
|
if self.draft_worker is None or self.spec_algorithm.is_ngram()
|
||||||
else self.draft_worker.model_runner.token_to_kv_pool
|
else self.draft_worker.model_runner.token_to_kv_pool
|
||||||
),
|
),
|
||||||
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||||
|
|||||||
@@ -246,7 +246,7 @@ class CudaGraphRunner:
|
|||||||
if (
|
if (
|
||||||
model_runner.spec_algorithm.is_eagle()
|
model_runner.spec_algorithm.is_eagle()
|
||||||
or model_runner.spec_algorithm.is_standalone()
|
or model_runner.spec_algorithm.is_standalone()
|
||||||
or model_runner.spec_algorithm.is_lookahead()
|
or model_runner.spec_algorithm.is_ngram()
|
||||||
):
|
):
|
||||||
if self.model_runner.is_draft_worker:
|
if self.model_runner.is_draft_worker:
|
||||||
raise RuntimeError("This should not happen")
|
raise RuntimeError("This should not happen")
|
||||||
@@ -413,12 +413,12 @@ class CudaGraphRunner:
|
|||||||
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
|
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
|
||||||
)
|
)
|
||||||
|
|
||||||
is_lookahead_supported = (
|
is_ngram_supported = (
|
||||||
(
|
(
|
||||||
forward_batch.batch_size * self.num_tokens_per_bs
|
forward_batch.batch_size * self.num_tokens_per_bs
|
||||||
== forward_batch.input_ids.numel()
|
== forward_batch.input_ids.numel()
|
||||||
)
|
)
|
||||||
if self.model_runner.spec_algorithm.is_lookahead()
|
if self.model_runner.spec_algorithm.is_ngram()
|
||||||
else True
|
else True
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -427,7 +427,7 @@ class CudaGraphRunner:
|
|||||||
and is_encoder_lens_supported
|
and is_encoder_lens_supported
|
||||||
and is_tbo_supported
|
and is_tbo_supported
|
||||||
and capture_hidden_mode_matches
|
and capture_hidden_mode_matches
|
||||||
and is_lookahead_supported
|
and is_ngram_supported
|
||||||
)
|
)
|
||||||
|
|
||||||
def capture(self) -> None:
|
def capture(self) -> None:
|
||||||
@@ -838,10 +838,10 @@ class CudaGraphRunner:
|
|||||||
seq_lens_cpu=None,
|
seq_lens_cpu=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif self.model_runner.spec_algorithm.is_lookahead():
|
elif self.model_runner.spec_algorithm.is_ngram():
|
||||||
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
||||||
|
|
||||||
spec_info = LookaheadVerifyInput(
|
spec_info = NgramVerifyInput(
|
||||||
draft_token=None,
|
draft_token=None,
|
||||||
tree_mask=self.custom_mask,
|
tree_mask=self.custom_mask,
|
||||||
positions=None,
|
positions=None,
|
||||||
|
|||||||
@@ -286,14 +286,14 @@ class ServerArgs:
|
|||||||
speculative_accept_threshold_acc: float = 1.0
|
speculative_accept_threshold_acc: float = 1.0
|
||||||
speculative_token_map: Optional[str] = None
|
speculative_token_map: Optional[str] = None
|
||||||
speculative_attention_mode: str = "prefill"
|
speculative_attention_mode: str = "prefill"
|
||||||
# For lookahead only
|
# For ngram only
|
||||||
speculative_lookahead_min_match_window_size: int = 1
|
speculative_ngram_min_match_window_size: int = 1
|
||||||
speculative_lookahead_max_match_window_size: int = 12
|
speculative_ngram_max_match_window_size: int = 12
|
||||||
speculative_lookahead_min_bfs_breadth: int = 1
|
speculative_ngram_min_bfs_breadth: int = 1
|
||||||
speculative_lookahead_max_bfs_breadth: int = 10
|
speculative_ngram_max_bfs_breadth: int = 10
|
||||||
speculative_lookahead_match_type: Literal["BFS", "PROB"] = "BFS"
|
speculative_ngram_match_type: Literal["BFS", "PROB"] = "BFS"
|
||||||
speculative_lookahead_branch_length: int = 18
|
speculative_ngram_branch_length: int = 18
|
||||||
speculative_lookahead_capacity: int = 10 * 1000 * 1000
|
speculative_ngram_capacity: int = 10 * 1000 * 1000
|
||||||
|
|
||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
ep_size: int = 1
|
ep_size: int = 1
|
||||||
@@ -566,7 +566,7 @@ class ServerArgs:
|
|||||||
# Standalone speculative decoding needs more memory than other speculative
|
# Standalone speculative decoding needs more memory than other speculative
|
||||||
# decoding algorithms since the draft model is typically larger.
|
# decoding algorithms since the draft model is typically larger.
|
||||||
reserved_mem += 6 * 1024
|
reserved_mem += 6 * 1024
|
||||||
elif self.speculative_algorithm != "LOOKAHEAD":
|
elif self.speculative_algorithm != "NGRAM":
|
||||||
reserved_mem += 2 * 1024
|
reserved_mem += 2 * 1024
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
reserved_mem += 4 * 1024
|
reserved_mem += 4 * 1024
|
||||||
@@ -1024,23 +1024,23 @@ class ServerArgs:
|
|||||||
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
|
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.speculative_algorithm == "LOOKAHEAD":
|
if self.speculative_algorithm == "NGRAM":
|
||||||
if not self.device.startswith("cuda"):
|
if not self.device.startswith("cuda"):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Lookahead speculative decoding only supports CUDA device."
|
"Ngram speculative decoding only supports CUDA device."
|
||||||
)
|
)
|
||||||
if self.max_running_requests is None:
|
if self.max_running_requests is None:
|
||||||
self.max_running_requests = 48
|
self.max_running_requests = 48
|
||||||
self.disable_overlap_schedule = True
|
self.disable_overlap_schedule = True
|
||||||
self.enable_mixed_chunk = False
|
self.enable_mixed_chunk = False
|
||||||
self.speculative_eagle_topk = self.speculative_lookahead_max_bfs_breadth
|
self.speculative_eagle_topk = self.speculative_ngram_max_bfs_breadth
|
||||||
if self.speculative_num_draft_tokens is None:
|
if self.speculative_num_draft_tokens is None:
|
||||||
self.speculative_num_draft_tokens = (
|
self.speculative_num_draft_tokens = (
|
||||||
self.speculative_lookahead_max_match_window_size
|
self.speculative_ngram_max_match_window_size
|
||||||
)
|
)
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"The overlap scheduler and mixed chunked prefill are disabled because of "
|
"The overlap scheduler and mixed chunked prefill are disabled because of "
|
||||||
"using lookahead speculative decoding."
|
"using ngram speculative decoding."
|
||||||
)
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
@@ -1052,9 +1052,9 @@ class ServerArgs:
|
|||||||
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
|
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
|
||||||
)
|
)
|
||||||
if self.enable_dp_attention:
|
if self.enable_dp_attention:
|
||||||
# TODO: support dp attention for lookahead speculative decoding
|
# TODO: support dp attention for ngram speculative decoding
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Currently lookahead speculative decoding does not support dp attention."
|
"Currently ngram speculative decoding does not support dp attention."
|
||||||
)
|
)
|
||||||
|
|
||||||
def _handle_load_format(self):
|
def _handle_load_format(self):
|
||||||
@@ -1921,7 +1921,7 @@ class ServerArgs:
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speculative-algorithm",
|
"--speculative-algorithm",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "LOOKAHEAD"],
|
choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "NGRAM"],
|
||||||
help="Speculative algorithm.",
|
help="Speculative algorithm.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
@@ -1981,49 +1981,49 @@ class ServerArgs:
|
|||||||
help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.",
|
help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.",
|
||||||
default=ServerArgs.speculative_attention_mode,
|
default=ServerArgs.speculative_attention_mode,
|
||||||
)
|
)
|
||||||
# Lookahead speculative decoding
|
# Ngram speculative decoding
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speculative-lookahead-min-match-window-size",
|
"--speculative-ngram-min-match-window-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=ServerArgs.speculative_lookahead_min_match_window_size,
|
default=ServerArgs.speculative_ngram_min_match_window_size,
|
||||||
help="The minimum window size for pattern matching in lookahead speculative decoding.",
|
help="The minimum window size for pattern matching in ngram speculative decoding.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speculative-lookahead-max-match-window-size",
|
"--speculative-ngram-max-match-window-size",
|
||||||
type=int,
|
type=int,
|
||||||
default=ServerArgs.speculative_lookahead_max_match_window_size,
|
default=ServerArgs.speculative_ngram_max_match_window_size,
|
||||||
help="The maximum window size for pattern matching in lookahead speculative decoding.",
|
help="The maximum window size for pattern matching in ngram speculative decoding.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speculative-lookahead-min-bfs-breadth",
|
"--speculative-ngram-min-bfs-breadth",
|
||||||
type=int,
|
type=int,
|
||||||
default=ServerArgs.speculative_lookahead_min_bfs_breadth,
|
default=ServerArgs.speculative_ngram_min_bfs_breadth,
|
||||||
help="The minimum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
|
help="The minimum breadth for BFS (Breadth-First Search) in ngram speculative decoding.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speculative-lookahead-max-bfs-breadth",
|
"--speculative-ngram-max-bfs-breadth",
|
||||||
type=int,
|
type=int,
|
||||||
default=ServerArgs.speculative_lookahead_max_bfs_breadth,
|
default=ServerArgs.speculative_ngram_max_bfs_breadth,
|
||||||
help="The maximum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
|
help="The maximum breadth for BFS (Breadth-First Search) in ngram speculative decoding.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speculative-lookahead-match-type",
|
"--speculative-ngram-match-type",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["BFS", "PROB"],
|
choices=["BFS", "PROB"],
|
||||||
default=ServerArgs.speculative_lookahead_match_type,
|
default=ServerArgs.speculative_ngram_match_type,
|
||||||
help="The match type for cache tree.",
|
help="The match type for cache tree.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speculative-lookahead-branch-length",
|
"--speculative-ngram-branch-length",
|
||||||
type=int,
|
type=int,
|
||||||
default=ServerArgs.speculative_lookahead_branch_length,
|
default=ServerArgs.speculative_ngram_branch_length,
|
||||||
help="The branch length for lookahead speculative decoding.",
|
help="The branch length for ngram speculative decoding.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--speculative-lookahead-capacity",
|
"--speculative-ngram-capacity",
|
||||||
type=int,
|
type=int,
|
||||||
default=ServerArgs.speculative_lookahead_capacity,
|
default=ServerArgs.speculative_ngram_capacity,
|
||||||
help="The cache capacity for lookahead speculative decoding.",
|
help="The cache capacity for ngram speculative decoding.",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Expert parallelism
|
# Expert parallelism
|
||||||
|
|||||||
@@ -1,16 +1,16 @@
|
|||||||
#include "lookahead.h"
|
#include "ngram.h"
|
||||||
|
|
||||||
#include <limits>
|
#include <limits>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace lookahead {
|
namespace ngram {
|
||||||
|
|
||||||
struct Node {
|
struct Node {
|
||||||
std::unordered_map<int32_t, int32_t> next;
|
std::unordered_map<int32_t, int32_t> next;
|
||||||
};
|
};
|
||||||
|
|
||||||
Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
|
Ngram::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
|
||||||
Lookahead::Result info;
|
Ngram::Result info;
|
||||||
std::vector<int32_t> prevs;
|
std::vector<int32_t> prevs;
|
||||||
info.token.reserve(draft_token_num);
|
info.token.reserve(draft_token_num);
|
||||||
prevs.reserve(draft_token_num);
|
prevs.reserve(draft_token_num);
|
||||||
@@ -50,7 +50,7 @@ Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<No
|
|||||||
return info;
|
return info;
|
||||||
}
|
}
|
||||||
|
|
||||||
Lookahead::Lookahead(size_t capacity, const Param& param) {
|
Ngram::Ngram(size_t capacity, const Param& param) {
|
||||||
param_ = param;
|
param_ = param;
|
||||||
nodes_.resize(capacity);
|
nodes_.resize(capacity);
|
||||||
for (auto& node : nodes_) {
|
for (auto& node : nodes_) {
|
||||||
@@ -116,17 +116,16 @@ Lookahead::Lookahead(size_t capacity, const Param& param) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
quit_flag_ = false;
|
quit_flag_ = false;
|
||||||
insert_worker_ = std::thread(&Lookahead::insert, this);
|
insert_worker_ = std::thread(&Ngram::insert, this);
|
||||||
}
|
}
|
||||||
|
|
||||||
Lookahead::~Lookahead() {
|
Ngram::~Ngram() {
|
||||||
quit_flag_ = true;
|
quit_flag_ = true;
|
||||||
insert_queue_.close();
|
insert_queue_.close();
|
||||||
insert_worker_.join();
|
insert_worker_.join();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::pair<TrieNode*, int32_t>>
|
std::vector<std::pair<TrieNode*, int32_t>> Ngram::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
||||||
Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
|
||||||
auto draft_token_num = param_.get_draft_token_num(batch_size);
|
auto draft_token_num = param_.get_draft_token_num(batch_size);
|
||||||
auto min_match_window_size = param_.get_min_match_window_size(batch_size);
|
auto min_match_window_size = param_.get_min_match_window_size(batch_size);
|
||||||
auto max_match_window_size = param_.max_match_window_size;
|
auto max_match_window_size = param_.max_match_window_size;
|
||||||
@@ -154,7 +153,7 @@ Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
|||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Lookahead::squeeze(size_t count) {
|
void Ngram::squeeze(size_t count) {
|
||||||
if (!(node_pool_.size() >= free_node_count_ + count)) {
|
if (!(node_pool_.size() >= free_node_count_ + count)) {
|
||||||
throw std::runtime_error(
|
throw std::runtime_error(
|
||||||
"Insufficient node size to release required nodes. "
|
"Insufficient node size to release required nodes. "
|
||||||
@@ -177,13 +176,13 @@ void Lookahead::squeeze(size_t count) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Lookahead::synchronize() const {
|
void Ngram::synchronize() const {
|
||||||
while (!insert_queue_.empty()) {
|
while (!insert_queue_.empty()) {
|
||||||
std::this_thread::sleep_for(std::chrono::microseconds(10));
|
std::this_thread::sleep_for(std::chrono::microseconds(10));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Lookahead::insert() {
|
void Ngram::insert() {
|
||||||
while (!quit_flag_) {
|
while (!quit_flag_) {
|
||||||
std::vector<int32_t> data;
|
std::vector<int32_t> data;
|
||||||
if (!insert_queue_.dequeue(data)) {
|
if (!insert_queue_.dequeue(data)) {
|
||||||
@@ -239,13 +238,13 @@ void Lookahead::insert() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void Lookahead::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
|
void Ngram::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
|
||||||
for (auto&& token : tokens) {
|
for (auto&& token : tokens) {
|
||||||
insert_queue_.enqueue(std::move(token));
|
insert_queue_.enqueue(std::move(token));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
Ngram::Result Ngram::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
||||||
std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
|
std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
|
||||||
|
|
||||||
double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /
|
double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /
|
||||||
@@ -284,7 +283,7 @@ Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t
|
|||||||
return fillResult(tokens.back(), draft_token_num + 1, tree, root);
|
return fillResult(tokens.back(), draft_token_num + 1, tree, root);
|
||||||
}
|
}
|
||||||
|
|
||||||
Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
Ngram::Result Ngram::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
||||||
std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
|
std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
|
||||||
auto draft_token_num = param_.get_draft_token_num(batch_size);
|
auto draft_token_num = param_.get_draft_token_num(batch_size);
|
||||||
|
|
||||||
@@ -346,10 +345,10 @@ Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_
|
|||||||
return fillResult(tokens.back(), draft_token_num + 1, tree, root);
|
return fillResult(tokens.back(), draft_token_num + 1, tree, root);
|
||||||
}
|
}
|
||||||
|
|
||||||
Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
|
Ngram::Result Ngram::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
|
||||||
std::unique_lock<std::mutex> lock(mutex_);
|
std::unique_lock<std::mutex> lock(mutex_);
|
||||||
Result merged_result;
|
Result merged_result;
|
||||||
auto match_func = param_.match_type == "BFS" ? &Lookahead::matchBFS : &Lookahead::matchProb;
|
auto match_func = param_.match_type == "BFS" ? &Ngram::matchBFS : &Ngram::matchProb;
|
||||||
for (const auto& tks : tokens) {
|
for (const auto& tks : tokens) {
|
||||||
Result res = (this->*match_func)(tks, tokens.size());
|
Result res = (this->*match_func)(tks, tokens.size());
|
||||||
merged_result.token.insert(merged_result.token.end(), res.token.begin(), res.token.end());
|
merged_result.token.insert(merged_result.token.end(), res.token.begin(), res.token.end());
|
||||||
@@ -358,7 +357,7 @@ Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>&
|
|||||||
return merged_result;
|
return merged_result;
|
||||||
}
|
}
|
||||||
|
|
||||||
void Lookahead::Result::truncate(size_t n) {
|
void Ngram::Result::truncate(size_t n) {
|
||||||
if (n < token.size()) {
|
if (n < token.size()) {
|
||||||
int full_n = token.size();
|
int full_n = token.size();
|
||||||
for (int i = 1; i < n; ++i) {
|
for (int i = 1; i < n; ++i) {
|
||||||
@@ -369,4 +368,4 @@ void Lookahead::Result::truncate(size_t n) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace lookahead
|
} // namespace ngram
|
||||||
@@ -15,7 +15,7 @@
|
|||||||
#include "param.h"
|
#include "param.h"
|
||||||
#include "queue.h"
|
#include "queue.h"
|
||||||
|
|
||||||
namespace lookahead {
|
namespace ngram {
|
||||||
|
|
||||||
struct TrieNode {
|
struct TrieNode {
|
||||||
std::unordered_map<int32_t, TrieNode*> child;
|
std::unordered_map<int32_t, TrieNode*> child;
|
||||||
@@ -34,7 +34,7 @@ struct TrieNode {
|
|||||||
std::multiset<TrieNode*, CompareByFreq> sorted_children;
|
std::multiset<TrieNode*, CompareByFreq> sorted_children;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Lookahead {
|
class Ngram {
|
||||||
std::vector<TrieNode> nodes_;
|
std::vector<TrieNode> nodes_;
|
||||||
std::vector<TrieNode*> node_pool_;
|
std::vector<TrieNode*> node_pool_;
|
||||||
size_t free_node_count_;
|
size_t free_node_count_;
|
||||||
@@ -61,12 +61,12 @@ class Lookahead {
|
|||||||
std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> match_tmp_data_;
|
std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> match_tmp_data_;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
Lookahead(size_t capacity, const Param& param);
|
Ngram(size_t capacity, const Param& param);
|
||||||
Lookahead() = default;
|
Ngram() = default;
|
||||||
~Lookahead();
|
~Ngram();
|
||||||
|
|
||||||
static Lookahead& instance() {
|
static Ngram& instance() {
|
||||||
static Lookahead instance;
|
static Ngram instance;
|
||||||
return instance;
|
return instance;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -107,4 +107,4 @@ class Lookahead {
|
|||||||
void insert();
|
void insert();
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace lookahead
|
} // namespace ngram
|
||||||
@@ -1,7 +1,5 @@
|
|||||||
# -*- coding: utf-8 -*-
|
# -*- coding: utf-8 -*-
|
||||||
|
|
||||||
# from sglang.op.lookahead import Lookahead, Param
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
@@ -12,17 +10,17 @@ from torch.utils.cpp_extension import load
|
|||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_abs_path = os.path.dirname(os.path.abspath(__file__))
|
_abs_path = os.path.dirname(os.path.abspath(__file__))
|
||||||
lookahead_cache_cpp = load(
|
ngram_cache_cpp = load(
|
||||||
name="lookahead_cache_cpp",
|
name="ngram_cache_cpp",
|
||||||
sources=[
|
sources=[
|
||||||
f"{_abs_path}/lookahead_cache_binding.cpp",
|
f"{_abs_path}/ngram_cache_binding.cpp",
|
||||||
f"{_abs_path}/lookahead.cpp",
|
f"{_abs_path}/ngram.cpp",
|
||||||
],
|
],
|
||||||
extra_cflags=["-O3", "-std=c++20"],
|
extra_cflags=["-O3", "-std=c++20"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LookaheadCache:
|
class NgramCache:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
branch_length=18,
|
branch_length=18,
|
||||||
@@ -34,7 +32,7 @@ class LookaheadCache:
|
|||||||
match_type="BFS",
|
match_type="BFS",
|
||||||
capacity=1000000,
|
capacity=1000000,
|
||||||
):
|
):
|
||||||
param = lookahead_cache_cpp.Param()
|
param = ngram_cache_cpp.Param()
|
||||||
param.branch_length = branch_length
|
param.branch_length = branch_length
|
||||||
param.min_match_window_size = min_match_window_size
|
param.min_match_window_size = min_match_window_size
|
||||||
param.max_match_window_size = max_match_window_size
|
param.max_match_window_size = max_match_window_size
|
||||||
@@ -42,7 +40,7 @@ class LookaheadCache:
|
|||||||
param.max_bfs_breadth = max_bfs_breadth
|
param.max_bfs_breadth = max_bfs_breadth
|
||||||
param.draft_token_num = draft_token_num
|
param.draft_token_num = draft_token_num
|
||||||
param.match_type = match_type
|
param.match_type = match_type
|
||||||
self.cache = lookahead_cache_cpp.Lookahead(capacity, param)
|
self.cache = ngram_cache_cpp.Ngram(capacity, param)
|
||||||
|
|
||||||
self.default_mask = np.ones((1, 1), dtype=np.int64)
|
self.default_mask = np.ones((1, 1), dtype=np.int64)
|
||||||
self.draft_token_num = draft_token_num
|
self.draft_token_num = draft_token_num
|
||||||
@@ -131,7 +129,7 @@ if __name__ == "__main__":
|
|||||||
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
|
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
|
||||||
[1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
|
[1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
|
||||||
]
|
]
|
||||||
cache = LookaheadCache(branch_length=12, draft_token_num=8)
|
cache = NgramCache(branch_length=12, draft_token_num=8)
|
||||||
cache.batch_put(token_ids)
|
cache.batch_put(token_ids)
|
||||||
|
|
||||||
cache.synchronize()
|
cache.synchronize()
|
||||||
@@ -1,19 +1,19 @@
|
|||||||
#include <pybind11/pybind11.h>
|
#include <pybind11/pybind11.h>
|
||||||
#include <pybind11/stl.h>
|
#include <pybind11/stl.h>
|
||||||
|
|
||||||
#include "lookahead.h"
|
#include "ngram.h"
|
||||||
|
|
||||||
PYBIND11_MODULE(lookahead_cache_cpp, m) {
|
PYBIND11_MODULE(ngram_cache_cpp, m) {
|
||||||
using namespace lookahead;
|
using namespace ngram;
|
||||||
namespace py = pybind11;
|
namespace py = pybind11;
|
||||||
m.doc() = "";
|
m.doc() = "";
|
||||||
|
|
||||||
py::class_<Lookahead>(m, "Lookahead")
|
py::class_<Ngram>(m, "Ngram")
|
||||||
.def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"))
|
.def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"))
|
||||||
.def("asyncInsert", &Lookahead::asyncInsert, "")
|
.def("asyncInsert", &Ngram::asyncInsert, "")
|
||||||
.def("batchMatch", &Lookahead::batchMatch, "")
|
.def("batchMatch", &Ngram::batchMatch, "")
|
||||||
.def("reset", &Lookahead::reset, "")
|
.def("reset", &Ngram::reset, "")
|
||||||
.def("synchronize", &Lookahead::synchronize, "");
|
.def("synchronize", &Ngram::synchronize, "");
|
||||||
|
|
||||||
py::class_<Param>(m, "Param")
|
py::class_<Param>(m, "Param")
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
@@ -35,9 +35,9 @@ PYBIND11_MODULE(lookahead_cache_cpp, m) {
|
|||||||
.def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
|
.def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
|
||||||
.def("detail", &Param::detail, "");
|
.def("detail", &Param::detail, "");
|
||||||
|
|
||||||
py::class_<Lookahead::Result>(m, "Result")
|
py::class_<Ngram::Result>(m, "Result")
|
||||||
.def(py::init<>())
|
.def(py::init<>())
|
||||||
.def_readwrite("token", &Lookahead::Result::token)
|
.def_readwrite("token", &Ngram::Result::token)
|
||||||
.def_readwrite("mask", &Lookahead::Result::mask)
|
.def_readwrite("mask", &Ngram::Result::mask)
|
||||||
.def("truncate", &Lookahead::Result::truncate);
|
.def("truncate", &Ngram::Result::truncate);
|
||||||
}
|
}
|
||||||
@@ -9,7 +9,7 @@
|
|||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
namespace lookahead {
|
namespace ngram {
|
||||||
|
|
||||||
struct Param {
|
struct Param {
|
||||||
bool enable;
|
bool enable;
|
||||||
@@ -122,4 +122,4 @@ struct Param {
|
|||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
} // namespace lookahead
|
} // namespace ngram
|
||||||
@@ -42,7 +42,7 @@ elif is_hip():
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LookaheadVerifyInput:
|
class NgramVerifyInput:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
draft_token: torch.Tensor,
|
draft_token: torch.Tensor,
|
||||||
@@ -408,5 +408,5 @@ class LookaheadVerifyInput:
|
|||||||
def filter_batch(self, new_indices: torch.Tensor):
|
def filter_batch(self, new_indices: torch.Tensor):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def merge_batch(self, spec_info: LookaheadVerifyInput):
|
def merge_batch(self, spec_info: NgramVerifyInput):
|
||||||
pass
|
pass
|
||||||
@@ -12,8 +12,8 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch
|
|||||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
from sglang.srt.speculative.cpp_lookahead.lookahead_cache import LookaheadCache
|
from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
|
||||||
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
from sglang.srt.utils import broadcast_pyobj
|
from sglang.srt.utils import broadcast_pyobj
|
||||||
|
|
||||||
@@ -22,7 +22,7 @@ logger = logging.getLogger(__name__)
|
|||||||
USE_FULL_MASK = True
|
USE_FULL_MASK = True
|
||||||
|
|
||||||
|
|
||||||
class LOOKAHEADWorker:
|
class NGRAMWorker:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
@@ -38,9 +38,9 @@ class LOOKAHEADWorker:
|
|||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
self.page_size = server_args.page_size
|
self.page_size = server_args.page_size
|
||||||
self.draft_token_num: int = server_args.speculative_num_draft_tokens
|
self.draft_token_num: int = server_args.speculative_num_draft_tokens
|
||||||
self.branch_length: int = server_args.speculative_lookahead_branch_length
|
self.branch_length: int = server_args.speculative_ngram_branch_length
|
||||||
self.max_match_window_size: int = (
|
self.max_match_window_size: int = (
|
||||||
server_args.speculative_lookahead_max_match_window_size
|
server_args.speculative_ngram_max_match_window_size
|
||||||
)
|
)
|
||||||
|
|
||||||
self.max_batch_size = target_worker.max_running_requests
|
self.max_batch_size = target_worker.max_running_requests
|
||||||
@@ -48,18 +48,18 @@ class LOOKAHEADWorker:
|
|||||||
|
|
||||||
self._init_preallocated_tensors()
|
self._init_preallocated_tensors()
|
||||||
|
|
||||||
self.lookahead_cache = LookaheadCache(
|
self.ngram_cache = NgramCache(
|
||||||
min_match_window_size=server_args.speculative_lookahead_min_match_window_size,
|
min_match_window_size=server_args.speculative_ngram_min_match_window_size,
|
||||||
max_match_window_size=server_args.speculative_lookahead_max_match_window_size,
|
max_match_window_size=server_args.speculative_ngram_max_match_window_size,
|
||||||
min_bfs_breadth=server_args.speculative_lookahead_min_bfs_breadth,
|
min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
|
||||||
max_bfs_breadth=server_args.speculative_lookahead_max_bfs_breadth,
|
max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
|
||||||
capacity=server_args.speculative_lookahead_capacity,
|
capacity=server_args.speculative_ngram_capacity,
|
||||||
branch_length=server_args.speculative_lookahead_branch_length,
|
branch_length=server_args.speculative_ngram_branch_length,
|
||||||
draft_token_num=server_args.speculative_num_draft_tokens,
|
draft_token_num=server_args.speculative_num_draft_tokens,
|
||||||
)
|
)
|
||||||
|
|
||||||
def clear_cache_pool(self):
|
def clear_cache_pool(self):
|
||||||
self.lookahead_cache.reset()
|
self.ngram_cache.reset()
|
||||||
|
|
||||||
def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int):
|
def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int):
|
||||||
seq2_len = len(seq2)
|
seq2_len = len(seq2)
|
||||||
@@ -124,14 +124,14 @@ class LOOKAHEADWorker:
|
|||||||
) -> tuple[np.ndarray, np.ndarray]:
|
) -> tuple[np.ndarray, np.ndarray]:
|
||||||
bs = batch.batch_size()
|
bs = batch.batch_size()
|
||||||
|
|
||||||
self.lookahead_cache.synchronize()
|
self.ngram_cache.synchronize()
|
||||||
batch_tokens = []
|
batch_tokens = []
|
||||||
for req in batch.reqs:
|
for req in batch.reqs:
|
||||||
check_token = self._efficient_concat_last_n(
|
check_token = self._efficient_concat_last_n(
|
||||||
req.origin_input_ids, req.output_ids, self.max_match_window_size
|
req.origin_input_ids, req.output_ids, self.max_match_window_size
|
||||||
)
|
)
|
||||||
batch_tokens.append(check_token)
|
batch_tokens.append(check_token)
|
||||||
req_drafts, mask = self.lookahead_cache.batch_get(batch_tokens)
|
req_drafts, mask = self.ngram_cache.batch_get(batch_tokens)
|
||||||
total_draft_token_num = len(req_drafts)
|
total_draft_token_num = len(req_drafts)
|
||||||
|
|
||||||
# Check if speculative decoding is needed; here we always enforce it
|
# Check if speculative decoding is needed; here we always enforce it
|
||||||
@@ -184,9 +184,9 @@ class LOOKAHEADWorker:
|
|||||||
tree_mask.append(req_mask.flatten())
|
tree_mask.append(req_mask.flatten())
|
||||||
tree_mask = torch.cat(tree_mask, dim=0)
|
tree_mask = torch.cat(tree_mask, dim=0)
|
||||||
|
|
||||||
batch.spec_algorithm = SpeculativeAlgorithm.LOOKAHEAD
|
batch.spec_algorithm = SpeculativeAlgorithm.NGRAM
|
||||||
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
batch.forward_mode = ForwardMode.TARGET_VERIFY
|
||||||
batch.spec_info = LookaheadVerifyInput(
|
batch.spec_info = NgramVerifyInput(
|
||||||
draft_tokens,
|
draft_tokens,
|
||||||
tree_mask,
|
tree_mask,
|
||||||
positions,
|
positions,
|
||||||
@@ -197,7 +197,7 @@ class LOOKAHEADWorker:
|
|||||||
)
|
)
|
||||||
batch.spec_info.prepare_for_verify(batch, self.page_size)
|
batch.spec_info.prepare_for_verify(batch, self.page_size)
|
||||||
|
|
||||||
def _update_lookahead_cache(self, batch: ScheduleBatch):
|
def _update_ngram_cache(self, batch: ScheduleBatch):
|
||||||
batch_tokens = []
|
batch_tokens = []
|
||||||
for req in batch.reqs:
|
for req in batch.reqs:
|
||||||
# FIXME: Whether to insert 'extend' into the cache or not, after testing,
|
# FIXME: Whether to insert 'extend' into the cache or not, after testing,
|
||||||
@@ -209,7 +209,7 @@ class LOOKAHEADWorker:
|
|||||||
req.origin_input_ids, req.output_ids, self.branch_length
|
req.origin_input_ids, req.output_ids, self.branch_length
|
||||||
)
|
)
|
||||||
batch_tokens.append(put_ids)
|
batch_tokens.append(put_ids)
|
||||||
self.lookahead_cache.batch_put(batch_tokens)
|
self.ngram_cache.batch_put(batch_tokens)
|
||||||
|
|
||||||
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
|
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
|
||||||
self._prepare_for_speculative_decoding(batch)
|
self._prepare_for_speculative_decoding(batch)
|
||||||
@@ -227,7 +227,7 @@ class LOOKAHEADWorker:
|
|||||||
logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
|
logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
|
||||||
batch, logits_output, self.page_size
|
batch, logits_output, self.page_size
|
||||||
)
|
)
|
||||||
self._update_lookahead_cache(batch)
|
self._update_ngram_cache(batch)
|
||||||
batch.forward_mode = ForwardMode.DECODE
|
batch.forward_mode = ForwardMode.DECODE
|
||||||
|
|
||||||
else:
|
else:
|
||||||
@@ -6,7 +6,7 @@ class SpeculativeAlgorithm(IntEnum):
|
|||||||
EAGLE = auto()
|
EAGLE = auto()
|
||||||
EAGLE3 = auto()
|
EAGLE3 = auto()
|
||||||
STANDALONE = auto()
|
STANDALONE = auto()
|
||||||
LOOKAHEAD = auto()
|
NGRAM = auto()
|
||||||
|
|
||||||
def is_none(self):
|
def is_none(self):
|
||||||
return self == SpeculativeAlgorithm.NONE
|
return self == SpeculativeAlgorithm.NONE
|
||||||
@@ -20,8 +20,8 @@ class SpeculativeAlgorithm(IntEnum):
|
|||||||
def is_standalone(self):
|
def is_standalone(self):
|
||||||
return self == SpeculativeAlgorithm.STANDALONE
|
return self == SpeculativeAlgorithm.STANDALONE
|
||||||
|
|
||||||
def is_lookahead(self):
|
def is_ngram(self):
|
||||||
return self == SpeculativeAlgorithm.LOOKAHEAD
|
return self == SpeculativeAlgorithm.NGRAM
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_string(name: str):
|
def from_string(name: str):
|
||||||
@@ -29,7 +29,7 @@ class SpeculativeAlgorithm(IntEnum):
|
|||||||
"EAGLE": SpeculativeAlgorithm.EAGLE,
|
"EAGLE": SpeculativeAlgorithm.EAGLE,
|
||||||
"EAGLE3": SpeculativeAlgorithm.EAGLE3,
|
"EAGLE3": SpeculativeAlgorithm.EAGLE3,
|
||||||
"STANDALONE": SpeculativeAlgorithm.STANDALONE,
|
"STANDALONE": SpeculativeAlgorithm.STANDALONE,
|
||||||
"LOOKAHEAD": SpeculativeAlgorithm.LOOKAHEAD,
|
"NGRAM": SpeculativeAlgorithm.NGRAM,
|
||||||
None: SpeculativeAlgorithm.NONE,
|
None: SpeculativeAlgorithm.NONE,
|
||||||
}
|
}
|
||||||
if name is not None:
|
if name is not None:
|
||||||
|
|||||||
@@ -82,7 +82,7 @@ DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
|
|||||||
"meta-llama/Llama-3.1-8B-Instruct"
|
"meta-llama/Llama-3.1-8B-Instruct"
|
||||||
)
|
)
|
||||||
DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
|
DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
|
||||||
DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
|
DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
|
||||||
|
|
||||||
# Other use cases
|
# Other use cases
|
||||||
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
|
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
|
||||||
|
|||||||
@@ -314,7 +314,7 @@ set(SOURCES
|
|||||||
"csrc/kvcacheio/transfer.cu"
|
"csrc/kvcacheio/transfer.cu"
|
||||||
|
|
||||||
"csrc/speculative/eagle_utils.cu"
|
"csrc/speculative/eagle_utils.cu"
|
||||||
"csrc/speculative/lookahead_utils.cu"
|
"csrc/speculative/ngram_utils.cu"
|
||||||
"csrc/speculative/packbit.cu"
|
"csrc/speculative/packbit.cu"
|
||||||
"csrc/speculative/speculative_sampling.cu"
|
"csrc/speculative/speculative_sampling.cu"
|
||||||
|
|
||||||
|
|||||||
@@ -79,7 +79,7 @@ suites = {
|
|||||||
TestFile("test_hidden_states.py", 55),
|
TestFile("test_hidden_states.py", 55),
|
||||||
TestFile("test_hybrid_attn_backend.py", 100),
|
TestFile("test_hybrid_attn_backend.py", 100),
|
||||||
TestFile("test_standalone_speculative_decoding.py", 250),
|
TestFile("test_standalone_speculative_decoding.py", 250),
|
||||||
TestFile("test_lookahead_speculative_decoding.py", 250),
|
TestFile("test_ngram_speculative_decoding.py", 250),
|
||||||
TestFile("test_input_embeddings.py", 38),
|
TestFile("test_input_embeddings.py", 38),
|
||||||
TestFile("test_io_struct.py", 8),
|
TestFile("test_io_struct.py", 8),
|
||||||
TestFile("test_jinja_template_utils.py", 1),
|
TestFile("test_jinja_template_utils.py", 1),
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import requests
|
|||||||
from sglang.srt.utils import kill_process_tree
|
from sglang.srt.utils import kill_process_tree
|
||||||
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
|
||||||
from sglang.test.test_utils import (
|
from sglang.test.test_utils import (
|
||||||
DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST,
|
DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST,
|
||||||
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
DEFAULT_URL_FOR_TEST,
|
DEFAULT_URL_FOR_TEST,
|
||||||
CustomTestCase,
|
CustomTestCase,
|
||||||
@@ -23,7 +23,7 @@ DEFAULT_SERVER_ARGS = [
|
|||||||
"--cuda-graph-max-bs",
|
"--cuda-graph-max-bs",
|
||||||
"8",
|
"8",
|
||||||
"--speculative-algorithm",
|
"--speculative-algorithm",
|
||||||
"LOOKAHEAD",
|
"NGRAM",
|
||||||
"--speculative-num-draft-tokens",
|
"--speculative-num-draft-tokens",
|
||||||
"16",
|
"16",
|
||||||
"--mem-fraction-static",
|
"--mem-fraction-static",
|
||||||
@@ -33,7 +33,7 @@ DEFAULT_SERVER_ARGS = [
|
|||||||
|
|
||||||
class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
|
class TestStandaloneSpeculativeDecodingBase(CustomTestCase):
|
||||||
|
|
||||||
model = DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST
|
model = DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST
|
||||||
base_url = DEFAULT_URL_FOR_TEST
|
base_url = DEFAULT_URL_FOR_TEST
|
||||||
accuracy_threshold = 0.79 # derived tests need to override this
|
accuracy_threshold = 0.79 # derived tests need to override this
|
||||||
spec_decode_threshold = 1.8 # derived spec decoding tests need to override this
|
spec_decode_threshold = 1.8 # derived spec decoding tests need to override this
|
||||||
Reference in New Issue
Block a user