[Feature] Speculative decoding support lookahead (#9873)
Co-authored-by: a4zhangfei <a4zhangfei@qq.com> Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
@@ -102,6 +102,8 @@ dev = ["sglang[test]"]
|
||||
"srt/layers/moe/fused_moe_triton/configs/*/*.json",
|
||||
"srt/layers/quantization/configs/*.json",
|
||||
"srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp",
|
||||
"srt/speculative/cpp_lookahead/*.cpp",
|
||||
"srt/speculative/cpp_lookahead/*.h",
|
||||
]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
|
||||
@@ -1110,7 +1110,8 @@ def sample_sharegpt_requests(
|
||||
add_generation_prompt=True,
|
||||
tokenize=False,
|
||||
)
|
||||
prompt = prompt.replace(tokenizer.bos_token, "")
|
||||
if tokenizer.bos_token:
|
||||
prompt = prompt.replace(tokenizer.bos_token, "")
|
||||
|
||||
prompt_token_ids = tokenizer.encode(prompt)
|
||||
completion = dataset[i][1]
|
||||
|
||||
@@ -29,6 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
||||
from sglang.srt.utils import (
|
||||
is_flashinfer_available,
|
||||
is_sm100_supported,
|
||||
@@ -317,7 +318,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
decode_wrappers = []
|
||||
@@ -422,7 +425,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
@@ -638,7 +643,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
@@ -651,7 +658,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||
self.call_begin_forward(
|
||||
@@ -673,7 +682,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
assert self.sliding_window_size is not None
|
||||
for wrapper_id in range(2):
|
||||
@@ -721,7 +732,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -753,7 +766,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
paged_kernel_lens_sum: int,
|
||||
kv_indptr: torch.Tensor,
|
||||
kv_start_idx: torch.Tensor,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
use_sliding_window_kv_pool: bool = False,
|
||||
):
|
||||
@@ -858,7 +873,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
@@ -873,7 +890,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
if use_ragged:
|
||||
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
|
||||
@@ -909,7 +928,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -955,7 +976,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -997,7 +1020,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
kv_indptr: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
use_ragged: bool,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
use_sliding_window_kv_pool: bool = False,
|
||||
):
|
||||
bs = len(seq_lens)
|
||||
@@ -1024,8 +1049,8 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
qo_indptr = qo_indptr[: bs + 1]
|
||||
custom_mask = None
|
||||
else:
|
||||
assert isinstance(spec_info, EagleDraftInput) or isinstance(
|
||||
spec_info, EagleVerifyInput
|
||||
assert isinstance(
|
||||
spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
|
||||
)
|
||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||
spec_info.generate_attn_arg_prefill(
|
||||
|
||||
@@ -74,6 +74,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
|
||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
@@ -950,7 +951,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
# Speculative decoding
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
] = None
|
||||
|
||||
# Whether to return hidden states
|
||||
return_hidden_states: bool = False
|
||||
@@ -1600,7 +1603,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
self.forward_mode = ForwardMode.DECODE
|
||||
bs = len(self.reqs)
|
||||
|
||||
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
|
||||
if (
|
||||
self.spec_algorithm.is_eagle()
|
||||
or self.spec_algorithm.is_standalone()
|
||||
or self.spec_algorithm.is_lookahead()
|
||||
):
|
||||
# if spec decoding is used, the decode batch is prepared inside
|
||||
# `forward_batch_speculative_generation` after running draft models.
|
||||
return
|
||||
@@ -1975,7 +1982,9 @@ class ModelWorkerBatch:
|
||||
|
||||
# Speculative decoding
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
|
||||
spec_info: Optional[
|
||||
Union[EagleVerifyInput, EagleDraftInput, LookaheadVerifyInput]
|
||||
] = None
|
||||
# If set, the output of the batch contains the hidden states of the run.
|
||||
capture_hidden_mode: CaptureHiddenMode = None
|
||||
hicache_consumer_index: int = -1
|
||||
|
||||
@@ -385,6 +385,18 @@ class Scheduler(
|
||||
target_worker=self.tp_worker,
|
||||
dp_rank=dp_rank,
|
||||
)
|
||||
elif self.spec_algorithm.is_lookahead():
|
||||
from sglang.srt.speculative.lookahead_worker import LOOKAHEADWorker
|
||||
|
||||
self.draft_worker = LOOKAHEADWorker(
|
||||
gpu_id=gpu_id,
|
||||
tp_rank=tp_rank,
|
||||
moe_ep_rank=moe_ep_rank,
|
||||
server_args=server_args,
|
||||
nccl_port=port_args.nccl_port,
|
||||
target_worker=self.tp_worker,
|
||||
dp_rank=dp_rank,
|
||||
)
|
||||
else:
|
||||
self.draft_worker = None
|
||||
|
||||
@@ -740,8 +752,8 @@ class Scheduler(
|
||||
else (
|
||||
server_args.speculative_num_draft_tokens
|
||||
+ (
|
||||
server_args.speculative_eagle_topk
|
||||
* server_args.speculative_num_steps
|
||||
(server_args.speculative_eagle_topk or 1)
|
||||
* (server_args.speculative_num_steps or 1)
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -784,7 +796,7 @@ class Scheduler(
|
||||
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||
draft_token_to_kv_pool=(
|
||||
None
|
||||
if self.draft_worker is None
|
||||
if self.draft_worker is None or self.spec_algorithm.is_lookahead()
|
||||
else self.draft_worker.model_runner.token_to_kv_pool
|
||||
),
|
||||
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||
@@ -821,7 +833,7 @@ class Scheduler(
|
||||
token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
|
||||
draft_token_to_kv_pool=(
|
||||
None
|
||||
if self.draft_worker is None
|
||||
if self.draft_worker is None or self.spec_algorithm.is_lookahead()
|
||||
else self.draft_worker.model_runner.token_to_kv_pool
|
||||
),
|
||||
req_to_metadata_buffer_idx_allocator=self.req_to_metadata_buffer_idx_allocator,
|
||||
@@ -2358,9 +2370,8 @@ class Scheduler(
|
||||
self.req_to_token_pool.clear()
|
||||
self.token_to_kv_pool_allocator.clear()
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
self.draft_worker.model_runner.req_to_token_pool.clear()
|
||||
self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
|
||||
if self.draft_worker:
|
||||
self.draft_worker.clear_cache_pool()
|
||||
|
||||
self.num_generated_tokens = 0
|
||||
self.forward_ct_decode = 0
|
||||
|
||||
@@ -84,6 +84,7 @@ from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicat
|
||||
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.tracing.trace import (
|
||||
trace_get_proc_propagate_context,
|
||||
trace_req_finish,
|
||||
@@ -174,6 +175,15 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
self.image_token_id = self.model_config.image_token_id
|
||||
self.max_req_input_len = None # Will be set later in engine.py
|
||||
|
||||
speculative_algorithm = SpeculativeAlgorithm.from_string(
|
||||
server_args.speculative_algorithm
|
||||
)
|
||||
self.reserve_input_token_num = (
|
||||
0
|
||||
if speculative_algorithm.is_none()
|
||||
else server_args.speculative_num_draft_tokens
|
||||
)
|
||||
|
||||
if self.model_config.is_multimodal:
|
||||
import_processors()
|
||||
try:
|
||||
@@ -618,6 +628,7 @@ class TokenizerManager(TokenizerCommunicatorMixin):
|
||||
_max_req_len = self.context_len
|
||||
|
||||
input_token_num = len(input_ids) if input_ids is not None else 0
|
||||
input_token_num += self.reserve_input_token_num
|
||||
if input_token_num >= self.context_len:
|
||||
if self.server_args.allow_auto_truncate:
|
||||
logger.warning(
|
||||
|
||||
@@ -275,6 +275,7 @@ class CudaGraphRunner:
|
||||
if (
|
||||
model_runner.spec_algorithm.is_eagle()
|
||||
or model_runner.spec_algorithm.is_standalone()
|
||||
or model_runner.spec_algorithm.is_lookahead()
|
||||
):
|
||||
if self.model_runner.is_draft_worker:
|
||||
raise RuntimeError("This should not happen")
|
||||
@@ -441,11 +442,21 @@ class CudaGraphRunner:
|
||||
forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
|
||||
)
|
||||
|
||||
is_lookahead_supported = (
|
||||
(
|
||||
forward_batch.batch_size * self.num_tokens_per_bs
|
||||
== forward_batch.input_ids.numel()
|
||||
)
|
||||
if self.model_runner.spec_algorithm.is_lookahead()
|
||||
else True
|
||||
)
|
||||
|
||||
return (
|
||||
is_bs_supported
|
||||
and is_encoder_lens_supported
|
||||
and is_tbo_supported
|
||||
and capture_hidden_mode_matches
|
||||
and is_lookahead_supported
|
||||
)
|
||||
|
||||
def capture(self) -> None:
|
||||
@@ -856,6 +867,20 @@ class CudaGraphRunner:
|
||||
seq_lens_cpu=None,
|
||||
)
|
||||
|
||||
elif self.model_runner.spec_algorithm.is_lookahead():
|
||||
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
||||
|
||||
spec_info = LookaheadVerifyInput(
|
||||
draft_token=None,
|
||||
tree_mask=self.custom_mask,
|
||||
positions=None,
|
||||
retrive_index=None,
|
||||
retrive_next_token=None,
|
||||
retrive_next_sibling=None,
|
||||
draft_token_num=self.num_tokens_per_bs,
|
||||
)
|
||||
spec_info.capture_hidden_mode = CaptureHiddenMode.NULL
|
||||
|
||||
return spec_info
|
||||
|
||||
|
||||
|
||||
@@ -1402,7 +1402,7 @@ class ModelRunner:
|
||||
if self.is_hybrid_gdn:
|
||||
max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)
|
||||
|
||||
if not self.spec_algorithm.is_none():
|
||||
if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
|
||||
if self.is_draft_worker:
|
||||
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
|
||||
max_num_reqs = self.server_args.max_num_reqs
|
||||
|
||||
@@ -286,6 +286,14 @@ class ServerArgs:
|
||||
speculative_accept_threshold_acc: float = 1.0
|
||||
speculative_token_map: Optional[str] = None
|
||||
speculative_attention_mode: str = "prefill"
|
||||
# For lookahead only
|
||||
speculative_lookahead_min_match_window_size: int = 1
|
||||
speculative_lookahead_max_match_window_size: int = 12
|
||||
speculative_lookahead_min_bfs_breadth: int = 1
|
||||
speculative_lookahead_max_bfs_breadth: int = 10
|
||||
speculative_lookahead_match_type: Literal["BFS", "PROB"] = "BFS"
|
||||
speculative_lookahead_branch_length: int = 18
|
||||
speculative_lookahead_capacity: int = 10 * 1000 * 1000
|
||||
|
||||
# Expert parallelism
|
||||
ep_size: int = 1
|
||||
@@ -529,7 +537,7 @@ class ServerArgs:
|
||||
# Standalone speculative decoding needs more memory than other speculative
|
||||
# decoding algorithms since the draft model is typically larger.
|
||||
reserved_mem += 6 * 1024
|
||||
else:
|
||||
elif self.speculative_algorithm != "LOOKAHEAD":
|
||||
reserved_mem += 2 * 1024
|
||||
if self.enable_dp_attention:
|
||||
reserved_mem += 4 * 1024
|
||||
@@ -780,11 +788,11 @@ class ServerArgs:
|
||||
self.speculative_algorithm = "EAGLE"
|
||||
|
||||
if self.speculative_algorithm in ("EAGLE", "EAGLE3", "STANDALONE"):
|
||||
if self.speculative_algorithm == "STANDALONE":
|
||||
if self.speculative_algorithm == "STANDALONE" and self.enable_dp_attention:
|
||||
# TODO: support dp attention for standalone speculative decoding
|
||||
assert (
|
||||
self.enable_dp_attention is False
|
||||
), "Currently standalone speculative decoding does not support dp attention."
|
||||
raise ValueError(
|
||||
"Currently standalone speculative decoding does not support dp attention."
|
||||
)
|
||||
if self.max_running_requests is None:
|
||||
self.max_running_requests = 48
|
||||
self.disable_overlap_schedule = True
|
||||
@@ -858,6 +866,39 @@ class ServerArgs:
|
||||
# If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
|
||||
# assert self.speculative_num_steps < self.speculative_num_draft_tokens
|
||||
|
||||
if self.speculative_algorithm == "LOOKAHEAD":
|
||||
if not self.device.startswith("cuda"):
|
||||
raise ValueError(
|
||||
"Lookahead speculative decoding only supports CUDA device."
|
||||
)
|
||||
if self.max_running_requests is None:
|
||||
self.max_running_requests = 48
|
||||
self.disable_overlap_schedule = True
|
||||
self.enable_mixed_chunk = False
|
||||
self.speculative_eagle_topk = self.speculative_lookahead_max_bfs_breadth
|
||||
if self.speculative_num_draft_tokens is None:
|
||||
# TODO: Do better auto choose in the future
|
||||
self.speculative_num_draft_tokens = (
|
||||
self.speculative_lookahead_max_match_window_size
|
||||
)
|
||||
logger.warning(
|
||||
"The overlap scheduler and mixed chunked prefill are disabled because of "
|
||||
"using lookahead speculative decoding."
|
||||
)
|
||||
if (
|
||||
self.speculative_eagle_topk > 1
|
||||
and self.page_size > 1
|
||||
and self.attention_backend != "flashinfer"
|
||||
):
|
||||
raise ValueError(
|
||||
"speculative_eagle_topk > 1 with page_size > 1 is unstable and produces incorrect results for paged attention backends. This combination is only supported for the 'flashinfer' backend."
|
||||
)
|
||||
|
||||
if self.enable_dp_attention:
|
||||
# TODO: support dp attention for lookahead speculative decoding
|
||||
raise ValueError(
|
||||
"Currently lookahead speculative decoding does not support dp attention."
|
||||
)
|
||||
# GGUF
|
||||
if (
|
||||
self.load_format == "auto" or self.load_format == "gguf"
|
||||
@@ -1690,7 +1731,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--speculative-algorithm",
|
||||
type=str,
|
||||
choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE"],
|
||||
choices=["EAGLE", "EAGLE3", "NEXTN", "STANDALONE", "LOOKAHEAD"],
|
||||
help="Speculative algorithm.",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -1750,6 +1791,50 @@ class ServerArgs:
|
||||
help="Attention backend for speculative decoding operations (both target verify and draft extend). Can be one of 'prefill' (default) or 'decode'.",
|
||||
default=ServerArgs.speculative_attention_mode,
|
||||
)
|
||||
# Lookahead speculative decoding
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-min-match-window-size",
|
||||
type=int,
|
||||
default=ServerArgs.speculative_lookahead_min_match_window_size,
|
||||
help="The minimum window size for pattern matching in lookahead speculative decoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-max-match-window-size",
|
||||
type=int,
|
||||
default=ServerArgs.speculative_lookahead_max_match_window_size,
|
||||
help="The maximum window size for pattern matching in lookahead speculative decoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-min-bfs-breadth",
|
||||
type=int,
|
||||
default=ServerArgs.speculative_lookahead_min_bfs_breadth,
|
||||
help="The minimum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-max-bfs-breadth",
|
||||
type=int,
|
||||
default=ServerArgs.speculative_lookahead_max_bfs_breadth,
|
||||
help="The maximum breadth for BFS (Breadth-First Search) in lookahead speculative decoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-match-type",
|
||||
type=str,
|
||||
choices=["BFS", "PROB"],
|
||||
default=ServerArgs.speculative_lookahead_match_type,
|
||||
help="The match type for cache tree.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-branch-length",
|
||||
type=int,
|
||||
default=ServerArgs.speculative_lookahead_branch_length,
|
||||
help="The branch length for lookahead speculative decoding.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--speculative-lookahead-capacity",
|
||||
type=int,
|
||||
default=ServerArgs.speculative_lookahead_capacity,
|
||||
help="The cache capacity for lookahead speculative decoding.",
|
||||
)
|
||||
|
||||
# Expert parallelism
|
||||
parser.add_argument(
|
||||
|
||||
1
python/sglang/srt/speculative/cpp_lookahead/.clang-format
Symbolic link
1
python/sglang/srt/speculative/cpp_lookahead/.clang-format
Symbolic link
@@ -0,0 +1 @@
|
||||
../../../../../sgl-kernel/.clang-format
|
||||
372
python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp
Normal file
372
python/sglang/srt/speculative/cpp_lookahead/lookahead.cpp
Normal file
@@ -0,0 +1,372 @@
|
||||
#include "lookahead.h"
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
namespace lookahead {
|
||||
|
||||
struct Node {
|
||||
std::unordered_map<int32_t, int32_t> next;
|
||||
};
|
||||
|
||||
Lookahead::Result fillResult(int last_token, int draft_token_num, std::vector<Node>& tree, int root) {
|
||||
Lookahead::Result info;
|
||||
std::vector<int32_t> prevs;
|
||||
info.token.reserve(draft_token_num);
|
||||
prevs.reserve(draft_token_num);
|
||||
std::queue<std::tuple<int32_t, int32_t, int32_t>> queue;
|
||||
info.token.emplace_back(last_token);
|
||||
prevs.emplace_back(-1);
|
||||
|
||||
for (auto [token, next] : tree[root].next) {
|
||||
queue.emplace(token, next, 0);
|
||||
}
|
||||
while (queue.size()) {
|
||||
auto [token, next, prev] = queue.front();
|
||||
queue.pop();
|
||||
info.token.emplace_back(token);
|
||||
prevs.emplace_back(prev);
|
||||
for (auto [t, n] : tree[next].next) {
|
||||
queue.emplace(t, n, info.token.size() - 1);
|
||||
}
|
||||
}
|
||||
|
||||
// zero padding to length
|
||||
while (info.token.size() < draft_token_num) {
|
||||
info.token.emplace_back(0);
|
||||
prevs.emplace_back(0);
|
||||
}
|
||||
|
||||
int n = info.token.size();
|
||||
info.mask.resize(n * n, 0);
|
||||
info.mask[0] = 1;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
if (prevs[i] != -1) {
|
||||
memcpy(&info.mask[i * n], &info.mask[prevs[i] * n], prevs[i] + 1);
|
||||
}
|
||||
info.mask[i * n + i] = 1;
|
||||
}
|
||||
|
||||
return info;
|
||||
}
|
||||
|
||||
Lookahead::Lookahead(size_t capacity, const Param& param) {
|
||||
param_ = param;
|
||||
nodes_.resize(capacity);
|
||||
for (auto& node : nodes_) {
|
||||
node_pool_.emplace_back(&node);
|
||||
}
|
||||
free_node_count_ = node_pool_.size();
|
||||
root_ = getNode();
|
||||
|
||||
if (!(param_.branch_length > 1)) {
|
||||
throw std::runtime_error(
|
||||
"param_.branch_length must be greater than 1, current value: " + std::to_string(param_.branch_length));
|
||||
}
|
||||
if (!(param_.min_match_window_size > 0)) {
|
||||
throw std::runtime_error(
|
||||
"min_match_window_size must be greater than 0, current value: " + std::to_string(param_.min_match_window_size));
|
||||
}
|
||||
if (!(param_.min_match_window_size <= param_.max_match_window_size)) {
|
||||
throw std::runtime_error(
|
||||
"min_match_window_size must be less than or equal to max_match_window_size, current min_match_window_size: " +
|
||||
std::to_string(param_.min_match_window_size) +
|
||||
", max_match_window_size: " + std::to_string(param_.max_match_window_size));
|
||||
}
|
||||
if (!(param_.max_match_window_size < param_.branch_length)) {
|
||||
throw std::runtime_error(
|
||||
"max_match_window_size must be less than branch_length, current max_match_window_size: " +
|
||||
std::to_string(param_.max_match_window_size) + ", branch_length: " + std::to_string(param_.branch_length));
|
||||
}
|
||||
if (!(param_.min_bfs_breadth > 0)) {
|
||||
throw std::runtime_error(
|
||||
"min_bfs_breadth must be greater than 0, current value: " + std::to_string(param_.min_bfs_breadth));
|
||||
}
|
||||
if (!(param_.min_bfs_breadth <= param_.max_bfs_breadth)) {
|
||||
throw std::runtime_error(
|
||||
"min_bfs_breadth must be less than or equal to max_bfs_breadth, current min_bfs_breadth: " +
|
||||
std::to_string(param_.min_bfs_breadth) + ", max_bfs_breadth: " + std::to_string(param_.max_bfs_breadth));
|
||||
}
|
||||
if (!(param_.draft_token_num > 0)) {
|
||||
throw std::runtime_error(
|
||||
"draft_token_num must be greater than 0, current value: " + std::to_string(param_.draft_token_num));
|
||||
}
|
||||
for (auto config : param_.batch_draft_token_num) {
|
||||
if (config != std::numeric_limits<decltype(config)>::max()) {
|
||||
if (!(config <= param_.draft_token_num)) {
|
||||
throw std::runtime_error(
|
||||
"batch_draft_token_num config value " + std::to_string(config) +
|
||||
" must be less than or equal to draft_token_num: " + std::to_string(param_.draft_token_num));
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto config : param_.batch_min_match_window_size) {
|
||||
if (config != std::numeric_limits<decltype(config)>::max()) {
|
||||
if (!(config >= param_.min_match_window_size)) {
|
||||
throw std::runtime_error(
|
||||
"batch_min_match_window_size config value " + std::to_string(config) +
|
||||
" must be greater than or equal to min_match_window_size: " + std::to_string(param_.min_match_window_size));
|
||||
}
|
||||
if (!(config <= param_.max_match_window_size)) {
|
||||
throw std::runtime_error(
|
||||
"batch_min_match_window_size config value " + std::to_string(config) +
|
||||
" must be less than or equal to max_match_window_size: " + std::to_string(param_.max_match_window_size));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
quit_flag_ = false;
|
||||
insert_worker_ = std::thread(&Lookahead::insert, this);
|
||||
}
|
||||
|
||||
Lookahead::~Lookahead() {
|
||||
quit_flag_ = true;
|
||||
insert_queue_.close();
|
||||
insert_worker_.join();
|
||||
}
|
||||
|
||||
std::vector<std::pair<TrieNode*, int32_t>>
|
||||
Lookahead::match(const std::vector<int32_t>& tokens, size_t batch_size) const {
|
||||
auto draft_token_num = param_.get_draft_token_num(batch_size);
|
||||
auto min_match_window_size = param_.get_min_match_window_size(batch_size);
|
||||
auto max_match_window_size = param_.max_match_window_size;
|
||||
std::vector<std::pair<TrieNode*, int32_t>> result;
|
||||
result.reserve(param_.max_match_window_size - param_.min_match_window_size);
|
||||
for (int32_t match_window_size = std::min(tokens.size(), param_.max_match_window_size);
|
||||
match_window_size >= param_.min_match_window_size;
|
||||
--match_window_size) {
|
||||
auto start = tokens.data() + tokens.size() - match_window_size;
|
||||
auto end = start + match_window_size;
|
||||
auto cursor = root_;
|
||||
while (start != end) {
|
||||
auto iter = cursor->child.find(*start);
|
||||
if (iter == cursor->child.end()) {
|
||||
cursor = nullptr;
|
||||
break;
|
||||
}
|
||||
++start;
|
||||
cursor = iter->second;
|
||||
}
|
||||
if (cursor) {
|
||||
result.emplace_back(std::make_pair(cursor, match_window_size));
|
||||
}
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
void Lookahead::squeeze(size_t count) {
|
||||
if (!(node_pool_.size() >= free_node_count_ + count)) {
|
||||
throw std::runtime_error(
|
||||
"Insufficient node size to release required nodes. "
|
||||
"available to release: " +
|
||||
std::to_string(node_pool_.size() - free_node_count_) + ", required to release: " + std::to_string(count));
|
||||
}
|
||||
while (count--) {
|
||||
auto last = global_lru_.back();
|
||||
global_lru_.pop_back();
|
||||
|
||||
if (!last->child.empty()) {
|
||||
throw std::runtime_error("The node to be released still has child nodes and cannot be released. ");
|
||||
}
|
||||
|
||||
last->parent->lru.erase(last->parent_lru_pos);
|
||||
last->parent->sorted_children.erase(last);
|
||||
last->parent->child.erase(last->token);
|
||||
|
||||
node_pool_[free_node_count_++] = last;
|
||||
}
|
||||
}
|
||||
|
||||
void Lookahead::synchronize() const {
|
||||
while (!insert_queue_.empty()) {
|
||||
std::this_thread::sleep_for(std::chrono::microseconds(10));
|
||||
}
|
||||
}
|
||||
|
||||
void Lookahead::insert() {
|
||||
while (!quit_flag_) {
|
||||
std::vector<int32_t> data;
|
||||
if (!insert_queue_.dequeue(data)) {
|
||||
continue;
|
||||
}
|
||||
const auto* token = data.data();
|
||||
size_t size = data.size();
|
||||
std::unique_lock<std::mutex> lock(mutex_);
|
||||
|
||||
for (size_t i = 0; i + param_.min_match_window_size < size; ++i) {
|
||||
auto start = token + i;
|
||||
auto end = start + std::min(size - i, param_.branch_length);
|
||||
|
||||
if (end - start > free_node_count_) {
|
||||
squeeze(end - start - free_node_count_);
|
||||
}
|
||||
|
||||
TrieNode* cursor = root_;
|
||||
path_.clear();
|
||||
while (start != end) {
|
||||
auto token = *start;
|
||||
auto iter = cursor->child.find(token);
|
||||
if (iter == cursor->child.end()) {
|
||||
iter = cursor->child.insert({token, getNode()}).first;
|
||||
auto node = iter->second;
|
||||
|
||||
cursor->lru.emplace_front(node);
|
||||
global_lru_.emplace_back(node);
|
||||
|
||||
node->token = token;
|
||||
node->parent = cursor;
|
||||
node->parent_lru_pos = cursor->lru.begin();
|
||||
node->global_lru_pos = --global_lru_.end();
|
||||
node->freq = 1;
|
||||
cursor->sorted_children.insert(node);
|
||||
} else {
|
||||
auto node = iter->second;
|
||||
cursor->sorted_children.erase(node);
|
||||
node->freq++;
|
||||
cursor->sorted_children.insert(node);
|
||||
cursor->lru.splice(cursor->lru.begin(), cursor->lru, node->parent_lru_pos);
|
||||
}
|
||||
cursor = iter->second;
|
||||
path_.emplace_back(cursor);
|
||||
++start;
|
||||
}
|
||||
|
||||
for (auto it = path_.rbegin(); it != path_.rend(); ++it) {
|
||||
TrieNode* node = *it;
|
||||
global_lru_.splice(global_lru_.begin(), global_lru_, node->global_lru_pos);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void Lookahead::asyncInsert(std::vector<std::vector<int32_t>>&& tokens) {
|
||||
for (auto&& token : tokens) {
|
||||
insert_queue_.enqueue(std::move(token));
|
||||
}
|
||||
}
|
||||
|
||||
// Build a draft-token tree via breadth-first traversal of the trie.
//
// `match()` returns the trie nodes reached by suffixes of `tokens` together
// with their match depth.  From each such node we BFS over children in
// recency order (the per-node `lru` list).  The traversal breadth starts
// larger for deeper (longer-suffix) matches and shrinks by
// `bfs_breadth_scale` per level, so confident matches get wider trees.
// At most `draft_token_num` tree positions are filled; position 0 is the root.
Lookahead::Result Lookahead::matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const {
  std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);

  // Breadth budget interpolated linearly over the match-window range.
  double bfs_breadth_scale = double(param_.max_bfs_breadth - param_.min_bfs_breadth) /
                             (param_.max_match_window_size - param_.min_match_window_size + 1);

  auto draft_token_num = param_.get_draft_token_num(batch_size);
  std::vector<Node> tree(draft_token_num + 1);
  int root = 0;
  int cursor = 1;  // next free slot in `tree`

  for (auto [node, depth] : nodes) {
    std::queue<std::tuple<int32_t, double, const TrieNode*>> queue;  // parent, bfs_breadth, node
    // Deeper matches (larger `depth`) start with a smaller remaining window
    // and therefore a larger initial breadth.
    queue.push({root, (param_.max_match_window_size - depth) * bfs_breadth_scale + param_.min_bfs_breadth, node});
    while (queue.size() && cursor <= draft_token_num) {
      auto front = queue.front();
      queue.pop();

      auto parent = std::get<0>(front);
      auto cur_breadth = std::get<1>(front);
      auto iter = std::get<2>(front)->lru.begin();  // most-recently-used child first

      // Always expand at least one child per node.
      auto breadth = std::max(1, int32_t(cur_breadth));
      for (int i = 0; i < breadth && iter != std::get<2>(front)->lru.end() && cursor <= draft_token_num; ++i, ++iter) {
        auto token = (*iter)->token;
        auto pos = -1;
        // Reuse an existing tree position when another start node already
        // emitted the same (parent, token) edge; otherwise allocate a slot.
        if (auto tit = tree[parent].next.find(token); tit != tree[parent].next.end()) {
          pos = tit->second;
        } else {
          pos = tree[parent].next.insert(std::make_pair(token, cursor++)).first->second;
        }
        // Children inherit a breadth reduced by one scale step per level.
        queue.emplace(pos, cur_breadth - bfs_breadth_scale, *iter);
      }
    }
  }

  // Serialize the tree into the flat token list + attention mask result.
  return fillResult(tokens.back(), draft_token_num + 1, tree, root);
}
|
||||
|
||||
// Build a draft-token tree by best-first expansion on estimated probability.
//
// Children of each trie node are weighted by their visit frequency
// (normalized over the node's top-k children), multiplied down the path to
// give an estimated path probability.  A max-heap always expands the most
// probable frontier entry next, until `draft_token_num` tree slots are used.
Lookahead::Result Lookahead::matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const {
  std::vector<std::pair<TrieNode*, int32_t>> nodes = match(tokens, batch_size);
  auto draft_token_num = param_.get_draft_token_num(batch_size);

  // Max-heap ordering on the final (third) tuple element: the path probability.
  struct CompareByLastDouble {
    bool operator()(
        const std::tuple<double, const TrieNode*, double>& a,  // parent_pos, node, final_prob
        const std::tuple<double, const TrieNode*, double>& b) const {
      return std::get<2>(a) < std::get<2>(b);
    }
  };

  // NOTE: the parent tree position is stored as a double (first element) and
  // implicitly converted back to an index on use.
  std::priority_queue<
      std::tuple<double, const TrieNode*, double>,
      std::vector<std::tuple<double, const TrieNode*, double>>,
      CompareByLastDouble>
      heap;

  std::vector<Node> tree(draft_token_num + 1);

  int root = 0;
  int cursor = 1;  // next free slot in `tree`
  int top_k = param_.max_bfs_breadth;

  // Push the top-k children of `trie_node` onto the heap, each carrying
  // prob * (freq / sum_of_topk_freqs).  sorted_children iterates in
  // descending-frequency order, so the first k children are the top-k.
  auto addToHeap = [&heap, &top_k](int parent, const TrieNode* trie_node, double prob) -> void {
    double sum_freq = 0.0;
    int count = 0;
    std::list<std::pair<TrieNode*, int32_t>> topk_children;
    for (auto* child : trie_node->sorted_children) {
      sum_freq += static_cast<double>(child->freq);
      topk_children.emplace_back(child, child->freq);
      if (++count >= top_k) break;
    }
    // Guard against division by zero when all frequencies are zero.
    if (sum_freq <= 0) sum_freq = 1.0;
    for (const auto& [child, freq] : topk_children) {
      double norm_freq = static_cast<double>(freq) / sum_freq * prob;
      heap.emplace(parent, child, norm_freq);
    }
  };

  for (auto [node, _] : nodes) {
    // Seed the frontier with the matched node's children at probability 1.
    addToHeap(root, node, 1.0);

    while (!heap.empty() && cursor <= draft_token_num) {
      auto [parent, trie_node, prob] = heap.top();  // parent_pos, node, final_prob
      heap.pop();
      auto token = trie_node->token;
      int pos = -1;
      // Merge duplicate (parent, token) edges produced by different seeds.
      auto tit = tree[parent].next.find(token);
      if (tit != tree[parent].next.end()) {
        pos = tit->second;
      } else {
        pos = cursor++;
        tree[parent].next[token] = pos;
      }
      // Continue best-first expansion below the newly placed node.
      addToHeap(pos, trie_node, prob);
    }
  }

  // Serialize the tree into the flat token list + attention mask result.
  return fillResult(tokens.back(), draft_token_num + 1, tree, root);
}
|
||||
|
||||
// Match every request of the batch under the trie lock and concatenate the
// per-request draft tokens and masks into one flat Result.
Lookahead::Result Lookahead::batchMatch(const std::vector<std::vector<int32_t>>& tokens) const {
  std::unique_lock<std::mutex> lock(mutex_);

  // The traversal strategy is fixed by configuration for the whole batch.
  const bool use_bfs = (param_.match_type == "BFS");
  const size_t batch_size = tokens.size();

  Result merged;
  for (const auto& request_tokens : tokens) {
    Result partial =
        use_bfs ? matchBFS(request_tokens, batch_size) : matchProb(request_tokens, batch_size);
    merged.token.insert(merged.token.end(), partial.token.begin(), partial.token.end());
    merged.mask.insert(merged.mask.end(), partial.mask.begin(), partial.mask.end());
  }
  return merged;
}
|
||||
|
||||
void Lookahead::Result::truncate(size_t n) {
|
||||
if (n < token.size()) {
|
||||
int full_n = token.size();
|
||||
for (int i = 1; i < n; ++i) {
|
||||
memcpy(&mask[i * n], &mask[i * full_n], sizeof(mask[0]) * n);
|
||||
}
|
||||
token.resize(n);
|
||||
mask.resize(n * n);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace lookahead
|
||||
110
python/sglang/srt/speculative/cpp_lookahead/lookahead.h
Normal file
110
python/sglang/srt/speculative/cpp_lookahead/lookahead.h
Normal file
@@ -0,0 +1,110 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <list>
|
||||
#include <mutex>
|
||||
#include <set>
|
||||
#include <sstream>
|
||||
#include <thread>
|
||||
#include <tuple>
|
||||
#include <unordered_map>
|
||||
#include <vector>
|
||||
|
||||
#include "param.h"
|
||||
#include "queue.h"
|
||||
|
||||
namespace lookahead {
|
||||
|
||||
// One node of the lookahead token trie.  Children are reachable three ways:
// by token id (`child`), in recency order (`lru`), and in descending
// frequency order (`sorted_children`).
struct TrieNode {
  std::unordered_map<int32_t, TrieNode*> child;         // token id -> child node
  std::list<TrieNode*>::const_iterator global_lru_pos;  // position in Lookahead::global_lru_
  std::list<TrieNode*>::const_iterator parent_lru_pos;  // position in parent->lru
  int32_t token;                                        // token on the edge from parent to this node
  TrieNode* parent;
  std::list<TrieNode*> lru;                             // children, most-recently-used first
  int32_t freq = 0;                                     // visit count accumulated during inserts

  // Orders nodes by descending freq, then ascending token id, then address,
  // so sorted_children iterates highest-frequency children first and the
  // order is fully deterministic.
  struct CompareByFreq {
    bool operator()(TrieNode* a, TrieNode* b) const {
      return std::tie(b->freq, a->token, a) < std::tie(a->freq, b->token, b);
    }
  };
  std::multiset<TrieNode*, CompareByFreq> sorted_children;
};
|
||||
|
||||
// Thread-safe lookahead n-gram cache: a capacity-bounded token trie with a
// background insert worker.  Queries (batchMatch) build per-request draft
// trees used for speculative decoding.
class Lookahead {
  // Preallocated node storage; nodes are never individually heap-allocated.
  std::vector<TrieNode> nodes_;
  // Free list over nodes_; the first free_node_count_ entries are available.
  std::vector<TrieNode*> node_pool_;
  size_t free_node_count_;
  // Recency list across all live trie nodes, used for capacity eviction.
  std::list<TrieNode*> global_lru_;
  TrieNode* root_;
  // Scratch path of nodes touched by the most recent insert walk.
  std::vector<TrieNode*> path_;
  Param param_;

  // Return trie nodes matching suffixes of `tokens`, paired with their match
  // depth.  Caller must hold mutex_.
  std::vector<std::pair<TrieNode*, int32_t>> match(const std::vector<int32_t>& tokens, size_t batch_size) const;

  // Evict nodes to free up `count` slots (LRU order).
  void squeeze(size_t count);

  // Pop a node from the free list and reset it via explicit destroy +
  // placement-new, so its containers and iterators start from a clean state.
  TrieNode* getNode() {
    auto node = node_pool_[--free_node_count_];
    node->~TrieNode();
    new (node) TrieNode();
    return node;
  }

  mutable std::mutex mutex_;
  bool quit_flag_;  // signals the insert worker to exit
  utils::Queue<std::vector<int32_t>> insert_queue_;  // sequences pending insertion
  std::thread insert_worker_;                        // background thread draining insert_queue_
  std::vector<std::tuple<int32_t, int32_t, int32_t, int32_t>> match_tmp_data_;  // match() scratch

 public:
  Lookahead(size_t capacity, const Param& param);
  Lookahead() = default;
  ~Lookahead();

  // Process-wide default-constructed instance.
  static Lookahead& instance() {
    static Lookahead instance;
    return instance;
  }

  // Block until the insert worker has drained all queued sequences.
  void synchronize() const;

  // Queue token sequences for insertion by the background worker.
  void asyncInsert(std::vector<std::vector<int32_t>>&& tokens);

  // Flat draft-token list plus its row-major square tree attention mask.
  struct Result {
    std::vector<int32_t> token;
    std::vector<uint8_t> mask;

    // Keep only the first n tokens and compact the mask accordingly.
    void truncate(size_t n);
  };

  // Match each request in `tokens` and concatenate the per-request results.
  Result batchMatch(const std::vector<std::vector<int32_t>>& tokens) const;

  // Drop all cached sequences: rebuild the free list over nodes_ and
  // allocate a fresh empty root.
  void reset() {
    std::unique_lock<std::mutex> lock(mutex_);

    global_lru_.clear();
    path_.clear();
    node_pool_.clear();
    for (auto& node : nodes_) {
      node_pool_.emplace_back(&node);
    }
    free_node_count_ = node_pool_.size();
    root_ = getNode();
  }

  const Param& param() const {
    return param_;
  }

 private:
  // Draft-tree construction strategies selected via Param::match_type.
  Result matchBFS(const std::vector<int32_t>& tokens, size_t batch_size) const;
  Result matchProb(const std::vector<int32_t>& tokens, size_t batch_size) const;

  // Insert-worker loop body: dequeue sequences and add them to the trie.
  void insert();
};
|
||||
|
||||
} // namespace lookahead
|
||||
140
python/sglang/srt/speculative/cpp_lookahead/lookahead_cache.py
Normal file
140
python/sglang/srt/speculative/cpp_lookahead/lookahead_cache.py
Normal file
@@ -0,0 +1,140 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
# from sglang.op.lookahead import Lookahead, Param
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Tuple
|
||||
|
||||
import numpy as np
|
||||
from torch.utils.cpp_extension import load
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_abs_path = os.path.dirname(os.path.abspath(__file__))
|
||||
lookahead_cache_cpp = load(
|
||||
name="lookahead_cache_cpp",
|
||||
sources=[
|
||||
f"{_abs_path}/lookahead_cache_binding.cpp",
|
||||
f"{_abs_path}/lookahead.cpp",
|
||||
],
|
||||
extra_cflags=["-O3", "-std=c++20"],
|
||||
)
|
||||
|
||||
|
||||
class LookaheadCache:
    """Python-side facade over the C++ lookahead trie cache.

    Inserted token sequences populate a frequency/recency-ordered trie in
    native code; a query returns draft token ids plus a flattened tree
    attention mask consumed by the speculative-decoding verifier.
    """

    def __init__(
        self,
        branch_length=18,
        min_match_window_size=1,
        max_match_window_size=10,
        min_bfs_breadth=1,
        max_bfs_breadth=8,
        draft_token_num=8,
        match_type="BFS",
        capacity=1000000,
    ):
        cfg = lookahead_cache_cpp.Param()
        cfg.branch_length = branch_length
        cfg.min_match_window_size = min_match_window_size
        cfg.max_match_window_size = max_match_window_size
        cfg.min_bfs_breadth = min_bfs_breadth
        cfg.max_bfs_breadth = max_bfs_breadth
        cfg.draft_token_num = draft_token_num
        cfg.match_type = match_type
        self.cache = lookahead_cache_cpp.Lookahead(capacity, cfg)

        # 1x1 all-ones mask kept as a trivial default — presumably for
        # requests with no draft tokens; confirm against callers.
        self.default_mask = np.ones((1, 1), dtype=np.int64)
        self.draft_token_num = draft_token_num

    def batch_put(self, batch_tokens: List[List[int]]):
        """Queue token sequences for asynchronous insertion into the trie."""
        self.cache.asyncInsert(batch_tokens)

    def synchronize(self):
        """Block until the background insert worker has drained its queue."""
        self.cache.synchronize()

    def reset(self):
        """Drop every cached sequence and restart from an empty trie."""
        self.cache.reset()

    def batch_get(self, batch_tokens: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
        """Return (draft_token_ids, flattened_tree_masks) for a batch."""
        matched = self.cache.batchMatch(batch_tokens)
        return np.array(matched.token), np.array(matched.mask)

    def leaf_paths_from_mask(
        self, tokens: List[int], tree_mask: List[List[int]]
    ) -> List[List[int]]:
        """
        Find all leaf paths according to the binary tree_mask (i.e., paths
        that are not prefixes of any other path).

        Args:
            tokens : List[int]       # token list corresponding to columns
            tree_mask : List[List[int]]  # nxn binary matrix; row r's set
                                         # columns form the path ending at r

        Returns:
            List[List[int]]  # token lists of only the leaf paths, in order
                             # of appearance
        """
        kept_sets = []
        kept_rows = []
        # Walk rows from last to first: a row is a leaf path unless its
        # column set is contained in (a prefix of) an already-kept path.
        for row_idx in range(len(tree_mask) - 1, -1, -1):
            col_set = {c for c, bit in enumerate(tree_mask[row_idx]) if bit == 1}
            if not any(col_set <= seen for seen in kept_sets):
                kept_sets.append(col_set)
                kept_rows.append(row_idx)

        # Restore original order of appearance.
        kept_rows.reverse()
        return [
            [tokens[c] for c in range(len(tokens)) if tree_mask[r][c] == 1]
            for r in kept_rows
        ]

    def debug_result(
        self, decoding_ids: np.ndarray, decoding_masks: np.ndarray, tokenizer=None
    ):
        """Log the draft trees of a batch, optionally decoded to text."""
        decoding_ids = decoding_ids.reshape(-1, self.draft_token_num)
        decoding_masks = decoding_masks.reshape(
            -1, self.draft_token_num, self.draft_token_num
        )
        logger.info(f"\n{decoding_ids=}\n{decoding_masks=}")
        for i in range(decoding_ids.shape[0]):
            leaf_paths = self.leaf_paths_from_mask(
                decoding_ids[i].tolist(), decoding_masks[i].tolist()
            )
            if tokenizer is None:
                logger.info(f"draft path {i}: {leaf_paths}")
            else:
                logger.info(f"result {i}:")
                for leaf_path in leaf_paths:
                    logger.info(
                        f"draft path {i}: {leaf_path} -> {tokenizer.decode(leaf_path, ensure_ascii=False)}"
                    )
||||
|
||||
|
||||
# Manual smoke test: insert two token sequences, then query draft trees for a
# few prefixes and log the resulting paths.
if __name__ == "__main__":
    # Plain string (no interpolation needed) and a name that does not shadow
    # the builtin `format`.
    log_format = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(
        level=logging.DEBUG,
        format=log_format,
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )

    token_ids = [
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
    ]
    cache = LookaheadCache(branch_length=12, draft_token_num=8)
    cache.batch_put(token_ids)

    # batch_put is fire-and-forget; wait for the insert worker before querying.
    cache.synchronize()
    decoding_ids, decoding_masks = cache.batch_get([[1, 2, 3], [3, 44], [3, 6, 999]])

    cache.debug_result(decoding_ids, decoding_masks)
|
||||
@@ -0,0 +1,43 @@
|
||||
#include <pybind11/pybind11.h>
|
||||
#include <pybind11/stl.h>
|
||||
|
||||
#include "lookahead.h"
|
||||
|
||||
// Python bindings for the lookahead cache.  Exposes the Lookahead trie, its
// configuration Param, and the (token, mask) Result type as module
// `lookahead_cache_cpp`.
PYBIND11_MODULE(lookahead_cache_cpp, m) {
  using namespace lookahead;
  namespace py = pybind11;
  m.doc() = "";

  // Core cache: construct with a node capacity plus a Param.
  py::class_<Lookahead>(m, "Lookahead")
      .def(py::init<size_t, const Param&>(), py::arg("capacity"), py::arg("param"))
      .def("asyncInsert", &Lookahead::asyncInsert, "")
      .def("batchMatch", &Lookahead::batchMatch, "")
      .def("reset", &Lookahead::reset, "")
      .def("synchronize", &Lookahead::synchronize, "");

  // Configuration struct; all fields are plain read/write attributes.
  py::class_<Param>(m, "Param")
      .def(py::init<>())
      .def_readwrite("enable", &Param::enable)
      .def_readwrite("enable_router_mode", &Param::enable_router_mode)
      .def_readwrite("min_bfs_breadth", &Param::min_bfs_breadth)
      .def_readwrite("max_bfs_breadth", &Param::max_bfs_breadth)
      .def_readwrite("min_match_window_size", &Param::min_match_window_size)
      .def_readwrite("max_match_window_size", &Param::max_match_window_size)
      .def_readwrite("branch_length", &Param::branch_length)
      .def_readwrite("draft_token_num", &Param::draft_token_num)
      .def_readwrite("match_type", &Param::match_type)
      .def_readwrite("batch_min_match_window_size", &Param::batch_min_match_window_size)
      .def_readwrite("batch_draft_token_num", &Param::batch_draft_token_num)
      .def("get_draft_token_num", &Param::get_draft_token_num, "")
      .def("get_min_match_window_size", &Param::get_min_match_window_size, "")
      .def("parse", &Param::parse, "")
      .def("resetBatchMinMatchWindowSize", &Param::resetBatchMinMatchWindowSize, "")
      .def("resetBatchReturnTokenNum", &Param::resetBatchReturnTokenNum, "")
      .def("detail", &Param::detail, "");

  // Match result: flat draft-token list + flattened square tree mask.
  py::class_<Lookahead::Result>(m, "Result")
      .def(py::init<>())
      .def_readwrite("token", &Lookahead::Result::token)
      .def_readwrite("mask", &Lookahead::Result::mask)
      .def("truncate", &Lookahead::Result::truncate);
}
|
||||
125
python/sglang/srt/speculative/cpp_lookahead/param.h
Normal file
125
python/sglang/srt/speculative/cpp_lookahead/param.h
Normal file
@@ -0,0 +1,125 @@
|
||||
#pragma once
|
||||
|
||||
#include <cstddef>
|
||||
#include <iostream>
|
||||
#include <limits>
|
||||
#include <regex>
|
||||
#include <sstream>
|
||||
#include <stdexcept>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
namespace lookahead {
|
||||
|
||||
// Runtime configuration for the lookahead speculative-decoding cache.
struct Param {
  bool enable;
  bool enable_router_mode;
  size_t min_bfs_breadth;        // traversal breadth at the shortest match window
  size_t max_bfs_breadth;        // traversal breadth at the longest match window
  size_t min_match_window_size;  // shortest suffix length matched in the trie
  size_t max_match_window_size;  // longest suffix length matched in the trie
  size_t branch_length;
  size_t draft_token_num;        // default draft-token budget per request
  std::string match_type;        // "BFS" selects matchBFS; anything else matchProb

  // Optional per-batch-size overrides; SIZE_MAX entries mean "no override".
  std::vector<size_t> batch_min_match_window_size;
  std::vector<size_t> batch_draft_token_num;

  // Draft-token budget for `batch_size`: per-batch override if configured,
  // otherwise draft_token_num - 1.
  size_t get_draft_token_num(size_t batch_size) const {
    if (batch_size < batch_draft_token_num.size()) {
      if (batch_draft_token_num[batch_size] !=
          std::numeric_limits<decltype(batch_draft_token_num)::value_type>::max()) {
        return batch_draft_token_num[batch_size];
      }
    }
    return draft_token_num - 1;
  }

  // Minimum match window for `batch_size`: per-batch override if configured,
  // otherwise the global min_match_window_size.
  size_t get_min_match_window_size(size_t batch_size) const {
    if (batch_size < batch_min_match_window_size.size()) {
      if (batch_min_match_window_size[batch_size] !=
          std::numeric_limits<decltype(batch_min_match_window_size)::value_type>::max()) {
        return batch_min_match_window_size[batch_size];
      }
    }
    return min_match_window_size;
  }

  // Parse an override string like "0-1|10,2-3|20" into a positional table:
  // for every "L-R|value" segment, positions L..R (inclusive) are set to
  // value; unspecified positions stay SIZE_MAX ("no override").  Throws
  // std::runtime_error on malformed segments, ranges outside [0, 128], or
  // overlapping positions.
  std::vector<size_t> parse(const std::string& value) {
    // Example input: 0-1|10,2-3|20,
    std::vector<size_t> result;
    if (value.empty()) {
      return result;
    }
    std::vector<size_t> mark;  // positions already assigned (overlap check)
    std::regex comma_re(",");
    std::sregex_token_iterator first{value.begin(), value.end(), comma_re, -1}, last;
    for (const auto& seg : std::vector<std::string>(first, last)) {
      // Split "L-R|value" on the pipe.
      std::regex pipe_re("\\|");
      std::sregex_token_iterator seg_first{seg.begin(), seg.end(), pipe_re, -1}, seg_last;
      std::vector<std::string> part(seg_first, seg_last);
      if (part.size() != 2) {
        throw std::runtime_error(
            "failed to get config, invalid config: " + seg + ", part's size = " + std::to_string(part.size()));
      }
      // Split "L-R" on the dash.
      std::regex endash_re("-");
      std::sregex_token_iterator range_first{part[0].begin(), part[0].end(), endash_re, -1}, range_last;
      std::vector<std::string> range(range_first, range_last);
      if (range.size() != 2) {
        throw std::runtime_error("failed to get range, invalid config: " + value);
      }
      size_t L = std::atoi(range[0].c_str());
      size_t R = std::atoi(range[1].c_str());
      if (L > R || R > 128) {
        throw std::runtime_error("invalid range, config: " + value);
      }
      if (R >= result.size()) {
        result.resize(R + 1, std::numeric_limits<decltype(result)::value_type>::max());
        mark.resize(result.size(), false);
      }
      size_t config = std::atoi(part[1].c_str());
      do {
        if (mark[L]) {
          throw std::runtime_error("repeated position " + std::to_string(L) + ", config : " + value);
        }
        mark[L] = true;
        result[L] = config;
      } while (++L <= R);
    }
    return result;
  }

  void resetBatchMinMatchWindowSize(const std::string& value) {
    batch_min_match_window_size = parse(value);
  }

  void resetBatchReturnTokenNum(const std::string& value) {
    batch_draft_token_num = parse(value);
  }

  // Human-readable dump of all fields for logging.
  std::string detail() {
    std::stringstream ss;
    ss << "enable = " << enable << ", enable_router_mode = " << enable_router_mode
       << ", min_bfs_breadth = " << min_bfs_breadth << ", max_bfs_breadth = " << max_bfs_breadth
       << ", min_match_window_size = " << min_match_window_size << ", max_match_window_size = " << max_match_window_size
       << ", branch_length = " << branch_length << ", draft_token_num = " << draft_token_num
       << ", match_type = " << match_type;
    ss << ", batch_min_match_window_size(" << batch_min_match_window_size.size() << ") = ";
    for (size_t i = 0; i < batch_min_match_window_size.size(); ++i) {
      ss << i << "|" << batch_min_match_window_size[i] << ",";
    }
    ss << ", batch_draft_token_num(" << batch_draft_token_num.size() << ") = ";
    for (size_t i = 0; i < batch_draft_token_num.size(); ++i) {
      ss << i << "|" << batch_draft_token_num[i] << ",";
    }
    return ss.str();
  }
};
|
||||
|
||||
} // namespace lookahead
|
||||
71
python/sglang/srt/speculative/cpp_lookahead/queue.h
Normal file
71
python/sglang/srt/speculative/cpp_lookahead/queue.h
Normal file
@@ -0,0 +1,71 @@
|
||||
#pragma once
|
||||
|
||||
#include <condition_variable>
|
||||
#include <queue>
|
||||
|
||||
namespace utils {
|
||||
|
||||
// Minimal thread-safe blocking FIFO, used to hand insert batches from the
// caller to the Lookahead background worker.
template <typename T>
class Queue {
 public:
  // Move an element in; returns false if the queue was already closed.
  bool enqueue(T&& rhs) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (closed_) {
        return false;
      }
      queue_.emplace(std::move(rhs));
    }
    // Notify outside the lock to avoid waking a waiter into a held mutex.
    cv_.notify_one();
    return true;
  }

  // Copying overload; same semantics as the move overload.
  bool enqueue(const T& rhs) {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      if (closed_) {
        return false;
      }
      queue_.emplace(rhs);
    }
    cv_.notify_one();
    return true;
  }

  // Block until an element is available or the queue is closed; returns
  // false on close.
  // NOTE(review): this returns false as soon as closed_ is set, even when
  // elements remain queued, so close() drops undrained items.  Looks
  // intentional for fast shutdown — confirm callers do not rely on draining.
  bool dequeue(T& rhs) {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return queue_.size() || closed_; });
    if (closed_) {
      return false;
    }
    rhs = std::move(queue_.front());
    queue_.pop();
    return true;
  }

  size_t size() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.size();
  }

  bool empty() const {
    std::lock_guard<std::mutex> lock(mutex_);
    return queue_.empty();
  }

  // Mark the queue closed and wake every waiter; later enqueues are rejected.
  void close() {
    {
      std::lock_guard<std::mutex> lock(mutex_);
      closed_ = true;
    }
    cv_.notify_all();
  }

 private:
  std::queue<T> queue_;
  mutable std::mutex mutex_;  // mutable so const size()/empty() can lock
  std::condition_variable cv_;
  bool closed_{false};
};
|
||||
|
||||
} // namespace utils
|
||||
@@ -771,6 +771,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
|
||||
return score_list, token_list, parents_list
|
||||
|
||||
    def clear_cache_pool(self):
        """Clear the request-to-token pool and the KV-cache allocator so both
        can be reused from an empty state."""
        self.model_runner.req_to_token_pool.clear()
        self.model_runner.token_to_kv_pool_allocator.clear()
|
||||
|
||||
def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
|
||||
spec_info.prepare_for_verify(batch, self.page_size)
|
||||
batch.return_hidden_states = False
|
||||
|
||||
412
python/sglang/srt/speculative/lookahead_utils.py
Normal file
412
python/sglang/srt/speculative/lookahead_utils.py
Normal file
@@ -0,0 +1,412 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import apply_custom_logit_processor
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
ScheduleBatch,
|
||||
get_last_loc,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.speculative.eagle_utils import (
|
||||
TREE_SPEC_KERNEL_AVAILABLE,
|
||||
assign_req_to_token_pool,
|
||||
create_flashinfer_kv_indices_triton,
|
||||
get_src_tgt_cache_loc,
|
||||
get_target_cache_loc,
|
||||
)
|
||||
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import (
|
||||
top_k_renorm_prob,
|
||||
top_p_renorm_prob,
|
||||
tree_speculative_sampling_target_only,
|
||||
verify_tree_greedy,
|
||||
)
|
||||
elif is_hip():
|
||||
from sgl_kernel import verify_tree_greedy
|
||||
|
||||
|
||||
@dataclass
|
||||
class LookaheadVerifyInput:
|
||||
def __init__(
|
||||
self,
|
||||
draft_token: torch.Tensor,
|
||||
tree_mask: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
retrive_index: torch.Tensor,
|
||||
retrive_next_token: torch.Tensor,
|
||||
retrive_next_sibling: torch.Tensor,
|
||||
draft_token_num: int,
|
||||
):
|
||||
self.draft_token = draft_token
|
||||
self.custom_mask = tree_mask
|
||||
self.positions = positions
|
||||
self.retrive_index = retrive_index
|
||||
self.retrive_next_token = retrive_next_token
|
||||
self.retrive_next_sibling = retrive_next_sibling
|
||||
self.draft_token_num = draft_token_num
|
||||
self.device = self.custom_mask.device
|
||||
|
||||
    def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
        """Allocate KV-cache slots for the draft tokens and register them in
        the request-to-token pool before the target model's verify pass.

        Sets batch.input_ids to the draft tokens and batch.out_cache_loc to
        the newly allocated slots.  No-op in idle forward mode.
        """
        if batch.forward_mode.is_idle():
            return

        batch.input_ids = self.draft_token

        if page_size == 1:
            # Token-granular allocator: any free slots will do.
            batch.out_cache_loc = batch.alloc_token_slots(len(batch.input_ids))
            end_offset = batch.seq_lens + self.draft_token_num
        else:
            # Paged allocator: extend each request's pages past its prefix.
            prefix_lens = batch.seq_lens
            end_offset = prefix_lens + self.draft_token_num
            last_loc = get_last_loc(
                batch.req_to_token_pool.req_to_token,
                batch.req_pool_indices,
                prefix_lens,
            )
            batch.out_cache_loc = batch.alloc_paged_token_slots_extend(
                prefix_lens, end_offset, last_loc, len(batch.input_ids)
            )
            # Kept on self; presumably reused by the paged free path — confirm.
            self.last_loc = last_loc

        bs = batch.batch_size()
        # Write the new slots into req_to_token for rows [seq_lens, end_offset)
        # of every request (one triton program per request).
        assign_req_to_token_pool[(bs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            end_offset,
            batch.out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            triton.next_power_of_2(bs),
        )
|
||||
|
||||
    def generate_attn_arg_prefill(
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        req_to_token: torch.Tensor,
    ):
        """Build attention-kernel arguments for the tree-verify prefill.

        Returns (kv_indices, cum_kv_seq_len, qo_indptr, custom_mask) where
        each request contributes draft_token_num query tokens and its KV
        length is extended by draft_token_num.
        """
        bs = len(req_pool_indices)

        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)

        # Each request's KV span grows by the draft tokens being verified.
        paged_kernel_lens = paged_kernel_lens + self.draft_token_num
        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

        # Fixed-size query ranges: draft_token_num tokens per request.
        self.qo_indptr = (
            torch.arange(0, bs + 1, dtype=torch.int32, device=self.device)
            * self.draft_token_num
        )

        kv_indices = torch.empty(
            cum_kv_seq_len[-1], dtype=torch.int32, device=self.device
        )

        # Gather each request's KV slot indices from req_to_token into the
        # flat kv_indices buffer (one triton program per request).
        create_flashinfer_kv_indices_triton[(bs,)](
            req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            cum_kv_seq_len,
            None,
            kv_indices,
            req_to_token.size(1),
        )
        return kv_indices, cum_kv_seq_len, self.qo_indptr, self.custom_mask
|
||||
|
||||
def _fill_requests(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
logits_output: torch.Tensor,
|
||||
):
|
||||
accept_index_cpu = self.accept_index.tolist()
|
||||
predict_cpu = self.predict.tolist()
|
||||
has_finished = False
|
||||
|
||||
# Iterate every accepted token and check if req has finished after append the token
|
||||
# should be checked BEFORE free kv cache slots
|
||||
for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
|
||||
for j, idx in enumerate(accept_index_row):
|
||||
if idx == -1:
|
||||
break
|
||||
id = predict_cpu[idx]
|
||||
req.output_ids.append(id)
|
||||
req.check_finished()
|
||||
if req.finished():
|
||||
has_finished = True
|
||||
# set all tokens after finished token to -1 and break
|
||||
self.accept_index[i, j + 1 :] = -1
|
||||
break
|
||||
else:
|
||||
if req.grammar is not None:
|
||||
try:
|
||||
req.grammar.accept_token(id)
|
||||
except ValueError as e:
|
||||
logger.info(
|
||||
f"{i=}, {req=}\n"
|
||||
f"{self.accept_index=}\n"
|
||||
f"{self.predict=}\n"
|
||||
)
|
||||
raise e
|
||||
req.spec_verify_ct += 1
|
||||
if has_finished:
|
||||
self.accept_length = (self.accept_index != -1).sum(dim=1) - 1
|
||||
self.accept_index = self.accept_index[self.accept_index != -1]
|
||||
|
||||
logits_output.next_token_logits = logits_output.next_token_logits[
|
||||
self.accept_index
|
||||
]
|
||||
if logits_output.hidden_states:
|
||||
logits_output.hidden_states = logits_output.hidden_states[self.accept_index]
|
||||
self.verified_id = self.predict[self.accept_index]
|
||||
|
||||
    def _free_cache(self, batch: ScheduleBatch, page_size: int):
        """Free KV-cache slots of rejected draft tokens and compact the
        accepted ones so each request's cache entries are contiguous again.
        """
        bs = batch.batch_size()
        # Free the KV cache for unaccepted tokens
        if page_size == 1:
            # TODO: boolean array index leads to a device sync. Remove it.
            evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
            evict_mask[self.accept_index] = False
            batch.token_to_kv_pool_allocator.free(batch.out_cache_loc[evict_mask])
            batch.out_cache_loc = batch.out_cache_loc[self.accept_index]
        else:
            # Shift the accepted tokens to the beginning.
            # Only evict the last part
            src_cache_loc, tgt_cache_loc, to_free_num_slots = get_src_tgt_cache_loc(
                batch.seq_lens,
                batch.out_cache_loc,
                self.accept_index,
                self.accept_length,
                self.draft_token_num,
                page_size,
            )
            to_free_slots = torch.empty(
                (to_free_num_slots.sum().item(),),
                dtype=torch.int64,
                device=to_free_num_slots.device,
            )

            # out_cache_loc: [0 1 2, 3 4 5, 6 7 8]
            # accept_index: [0 -1 2, 3 4 -1, 6 -1 -1]
            # tgt_cache_loc: [0 1 , 3 4 , 6 ]
            # to_free_slots: [ 2, 5, 7 8]
            # to_free_slots also needs to be page-aligned without the first partial page
            #
            # split each row of out_cache_loc into two parts.
            # 1. the first part goes to tgt_cache_loc. length = accept_length[i] + 1
            # 2. the second part goes to to_free_slots.
            get_target_cache_loc[(bs,)](
                tgt_cache_loc,
                to_free_slots,
                self.accept_length,
                to_free_num_slots,
                batch.out_cache_loc,
                self.draft_token_num,
                next_power_of_2(self.draft_token_num),
                next_power_of_2(bs),
            )

            # Free the kv cache
            batch.token_to_kv_pool_allocator.free(to_free_slots)

            # Copy the kv cache
            batch.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
                tgt_cache_loc, src_cache_loc
            )
            batch.out_cache_loc = tgt_cache_loc

        # Re-register the compacted accepted slots in req_to_token for rows
        # [seq_lens, seq_lens + accept_length + 1) of each request.
        assign_req_to_token_pool[(bs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            batch.seq_lens + self.accept_length + 1,
            batch.out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            triton.next_power_of_2(bs),
        )
|
||||
|
||||
    def _greedy_verify(
        self,
        batch: ScheduleBatch,
        logits_output: LogitsProcessorOutput,
    ):
        """Greedy (argmax) verification: accept draft-tree paths whose tokens
        match the target model's argmax predictions.

        Fills self.predict, self.accept_index and self.accept_length in place
        via the verify_tree_greedy kernel.
        """
        bs = batch.batch_size()
        # Target model's greedy choice at every draft position.
        target_predict = torch.argmax(logits_output.next_token_logits, dim=-1)
        target_predict = target_predict.reshape(bs, self.draft_token_num)

        candidates = self.draft_token.reshape(bs, self.draft_token_num)
        # predict gets one extra slot beyond the logits rows (bonus token).
        predict_shape = list(logits_output.next_token_logits.shape)[:-1]
        predict_shape[-1] += 1
        self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device)
        self.accept_index = torch.full(
            (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device
        )
        self.accept_length = torch.empty((bs,), dtype=torch.int32, device=self.device)

        verify_tree_greedy(
            predicts=self.predict,  # mutable
            accept_index=self.accept_index,  # mutable
            accept_token_num=self.accept_length,  # mutable
            candidates=candidates,
            retrive_index=self.retrive_index,
            retrive_next_token=self.retrive_next_token,
            retrive_next_sibling=self.retrive_next_sibling,
            target_predict=target_predict,
        )
|
||||
|
||||
def _sampling_verify(
    self,
    batch: ScheduleBatch,
    logits_output: LogitsProcessorOutput,
    sampling_info: SamplingBatchInfo,
):
    """Verify draft tokens with tree-based rejection sampling.

    Applies per-request temperature / top-k / top-p renormalization to the
    target logits, then runs ``tree_speculative_sampling_target_only``, which
    fills ``self.predict``, ``self.accept_index`` and ``self.accept_length``
    in place.

    NOTE: the two ``torch.rand*`` calls below consume the global RNG stream in
    a fixed order; do not reorder them.
    """
    bs = batch.batch_size()
    candidates = self.draft_token.reshape(bs, self.draft_token_num)
    # Output buffer has one extra slot along the last dimension.
    predict_shape = list(logits_output.next_token_logits.shape)[:-1]
    predict_shape[-1] += 1
    self.predict = torch.empty(predict_shape, dtype=torch.int32, device=self.device)
    self.accept_index = torch.full(
        (bs, self.draft_token_num), -1, dtype=torch.int32, device=self.device
    )
    self.accept_length = torch.empty((bs,), dtype=torch.int32, device=self.device)
    # apply temperature and get target probs
    expanded_temperature = torch.repeat_interleave(
        sampling_info.temperatures, self.draft_token_num, dim=0
    )  # (bs * draft_token_num, 1)

    target_probs = F.softmax(
        logits_output.next_token_logits / expanded_temperature, dim=-1
    )  # (bs * draft_token_num, vocab_size)

    # NOTE: The test shows that top_p_renorm_prob and top_k_renorm_prob are the key factors
    # contributing to the poor performance of _sampling_verify.
    target_probs = top_k_renorm_prob(
        target_probs,
        torch.repeat_interleave(sampling_info.top_ks, self.draft_token_num, dim=0),
    )  # (bs * draft_token_num, vocab_size)

    if sampling_info.need_top_p_sampling:
        # logger.info("Using top-p sampling in speculative decoding verification.")
        target_probs = top_p_renorm_prob(
            target_probs,
            torch.repeat_interleave(
                sampling_info.top_ps, self.draft_token_num, dim=0
            ),
        )

    target_probs = target_probs.reshape(bs, self.draft_token_num, -1)
    # Lookahead has no draft model, so draft probabilities are all-zero.
    draft_probs = torch.zeros(
        target_probs.shape, dtype=torch.float32, device=self.device
    )

    # coins for rejection sampling
    coins = torch.rand_like(candidates, dtype=torch.float32, device=self.device)
    # coins for final sampling
    coins_for_final_sampling = torch.rand(
        (bs,), dtype=torch.float32, device=self.device
    )
    tree_speculative_sampling_target_only(
        predicts=self.predict,  # mutable
        accept_index=self.accept_index,  # mutable
        accept_token_num=self.accept_length,  # mutable
        candidates=candidates.to(torch.int64),
        retrive_index=self.retrive_index.to(torch.int64),
        retrive_next_token=self.retrive_next_token.to(torch.int64),
        retrive_next_sibling=self.retrive_next_sibling.to(torch.int64),
        uniform_samples=coins,
        uniform_samples_for_final_sampling=coins_for_final_sampling,
        target_probs=target_probs,
        draft_probs=draft_probs,
        threshold_single=global_server_args_dict[
            "speculative_accept_threshold_single"
        ],
        threshold_acc=global_server_args_dict["speculative_accept_threshold_acc"],
        deterministic=True,
    )
|
||||
|
||||
def verify(
    self,
    batch: ScheduleBatch,
    logits_output: LogitsProcessorOutput,
    page_size: int,
    vocab_mask: Optional[torch.Tensor] = None,  # For grammar
) -> tuple:
    """Verify the draft tokens against the target model's logits.

    Mutates ``logits_output.next_token_logits`` in place (penalties, grammar
    mask), frees the KV cache of rejected tokens, and advances
    ``batch.seq_lens`` by the number of accepted tokens plus one.

    Returns:
        ``(logits_output, verified_id, num_accepted_tokens)``.
    """
    bs = self.retrive_index.shape[0]
    sampling_info = batch.sampling_info

    # Some requests may have finished since this spec_info was built; trim
    # the sampling info down to the surviving requests on a copy.
    if bs != len(sampling_info):
        sampling_info = copy.deepcopy(sampling_info)
        # NOTE: retrive_index are the indices of the requests that are kept.
        sampling_info.filter_batch(self.retrive_index.tolist(), self.retrive_index)

    # Apply the custom logit processors if registered in the sampling info.
    if sampling_info.has_custom_logit_processor:
        apply_custom_logit_processor(
            logits_output.next_token_logits,
            sampling_info,
            num_tokens_in_batch=self.draft_token_num,
        )

    # Apply penalty
    if sampling_info.penalizer_orchestrator.is_required:
        # This is a relaxed version of penalties for speculative decoding.
        # One penalty row per request, broadcast over its draft tokens.
        linear_penalty = torch.zeros(
            (bs, logits_output.next_token_logits.shape[1]),
            dtype=torch.float32,
            device=self.device,
        )
        sampling_info.apply_logits_bias(linear_penalty)
        logits_output.next_token_logits.add_(
            torch.repeat_interleave(linear_penalty, self.draft_token_num, dim=0)
        )

    # Apply grammar mask
    if vocab_mask is not None:
        assert self.grammar is not None
        self.grammar.apply_vocab_mask(
            logits=logits_output.next_token_logits, vocab_mask=vocab_mask
        )

    # Sample tokens. Force greedy sampling on AMD
    is_all_greedy = sampling_info.is_all_greedy
    if (not is_all_greedy) and (not TREE_SPEC_KERNEL_AVAILABLE):
        logger.warning(
            "Tree speculative sampling kernel unavailable (likely AMD/HIP build). "
            "Falling back to greedy verification."
        )

    if is_all_greedy or not TREE_SPEC_KERNEL_AVAILABLE:
        self._greedy_verify(batch, logits_output)
    else:
        # NOTE: Compared with greedy_verify, the performance of _sampling_verify is relatively poor.
        self._greedy_verify(batch, logits_output)
        # self._sampling_verify(batch, logits_output, sampling_info)

    self._fill_requests(batch, logits_output)
    self._free_cache(batch, page_size)

    # +1 accounts for the token emitted after the last accepted draft token.
    batch.seq_lens.add_(self.accept_length + 1)
    batch.seq_lens_sum = torch.sum(batch.seq_lens).item()

    return logits_output, self.verified_id, self.accept_length.sum().item()
|
||||
|
||||
def filter_batch(self, new_indices: torch.Tensor):
    """Intentionally a no-op for lookahead verification inputs."""
    return None
|
||||
|
||||
def merge_batch(self, spec_info: LookaheadVerifyInput):
    # Intentionally a no-op: lookahead verify inputs carry no state to merge.
    pass
|
||||
244
python/sglang/srt/speculative/lookahead_worker.py
Normal file
244
python/sglang/srt/speculative/lookahead_worker.py
Normal file
@@ -0,0 +1,244 @@
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from sgl_kernel.speculative import reconstruct_indices_from_tree_mask
|
||||
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
from sglang.srt.managers.tp_worker import TpModelWorker
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.cpp_lookahead.lookahead_cache import LookaheadCache
|
||||
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import broadcast_pyobj
|
||||
|
||||
logger = logging.getLogger(__name__)

# If True, build a full (prefix-ones + tree) attention mask per request for
# verification. A QLEN-only mask is faster (~8%, scaling with batch size) but
# requires corresponding changes in flashinfer.
USE_FULL_MASK = True
|
||||
|
||||
|
||||
class LOOKAHEADWorker:
    """Speculative-decoding worker that drafts tokens from an n-gram lookahead cache.

    Instead of running a separate draft model, draft tokens are retrieved from a
    ``LookaheadCache`` keyed by each request's most recent tokens, then verified
    in a single target-model pass (``ForwardMode.TARGET_VERIFY``).
    """

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        self.target_worker = target_worker
        self.model_runner = target_worker.model_runner
        self.tp_rank = tp_rank
        self.page_size = server_args.page_size
        # Number of draft tokens verified per request per step.
        self.draft_token_num: int = server_args.speculative_num_draft_tokens
        # Length of the token suffix inserted into the cache after each step.
        self.branch_length: int = server_args.speculative_lookahead_branch_length
        # Longest recent-token suffix used to look up drafts in the cache.
        self.max_match_window_size: int = (
            server_args.speculative_lookahead_max_match_window_size
        )

        self.max_batch_size = target_worker.max_running_requests
        self.device = f"cuda:{gpu_id}" if gpu_id >= 0 else "cuda"

        self._init_preallocated_tensors()

        self.lookahead_cache = LookaheadCache(
            min_match_window_size=server_args.speculative_lookahead_min_match_window_size,
            max_match_window_size=server_args.speculative_lookahead_max_match_window_size,
            min_bfs_breadth=server_args.speculative_lookahead_min_bfs_breadth,
            max_bfs_breadth=server_args.speculative_lookahead_max_bfs_breadth,
            capacity=server_args.speculative_lookahead_capacity,
            branch_length=server_args.speculative_lookahead_branch_length,
            draft_token_num=server_args.speculative_num_draft_tokens,
        )

    def clear_cache_pool(self):
        """Reset the lookahead n-gram cache."""
        self.lookahead_cache.reset()

    def _efficient_concat_last_n(
        self, seq1: List[int], seq2: List[int], n: int
    ) -> List[int]:
        """Return the last ``n`` elements of ``seq1 + seq2`` without building the full concatenation."""
        if n <= 0:
            # Guard: ``seq2[-0:]`` would return the WHOLE list, not the last
            # zero elements, so handle the degenerate window explicitly.
            return []
        seq2_len = len(seq2)
        if seq2_len >= n:
            return seq2[-n:]

        need_from_seq1 = n - seq2_len
        return seq1[-need_from_seq1:] + seq2

    def _init_preallocated_tensors(self):
        """Allocate reusable device tensors sized for the largest batch.

        Per-batch-size slices are precomputed into ``*_batch`` lists so the
        hot path only indexes by batch size instead of re-allocating.
        """
        max_total_drafts = self.max_batch_size * self.draft_token_num
        max_total_mask_size = (
            self.max_batch_size * self.draft_token_num * self.draft_token_num
        )

        self.draft_tokens = torch.empty(
            (max_total_drafts,), dtype=torch.int64, device=self.device
        )
        self.retrieve_indexes = torch.empty(
            (self.max_batch_size, self.draft_token_num),
            dtype=torch.int64,
            device=self.device,
        )
        self.retrive_next_token = torch.empty(
            (self.max_batch_size, self.draft_token_num),
            dtype=torch.int64,
            device=self.device,
        )
        self.retrive_next_sibling = torch.empty(
            (self.max_batch_size, self.draft_token_num),
            dtype=torch.int64,
            device=self.device,
        )
        self.positions = torch.empty(
            (max_total_drafts,), dtype=torch.int64, device=self.device
        )
        self.tree_mask = torch.empty(
            (max_total_mask_size,), dtype=torch.bool, device=self.device
        )

        self.draft_tokens_batch = []
        self.tree_mask_batch = []
        self.retrieve_indexes_batch = []
        self.retrive_next_token_batch = []
        self.retrive_next_sibling_batch = []
        self.positions_batch = []

        # Index bs of each list is the view for a batch of size bs (0..max).
        for bs in range(0, self.max_batch_size + 1):
            self.retrieve_indexes_batch.append(self.retrieve_indexes[:bs, :])
            self.retrive_next_token_batch.append(self.retrive_next_token[:bs, :])
            self.retrive_next_sibling_batch.append(self.retrive_next_sibling[:bs, :])
            self.positions_batch.append(self.positions[: bs * self.draft_token_num])
            self.draft_tokens_batch.append(
                self.draft_tokens[: bs * self.draft_token_num]
            )
            self.tree_mask_batch.append(
                self.tree_mask[: bs * self.draft_token_num * self.draft_token_num]
            )

    def _prepare_draft_tokens(
        self, batch: ScheduleBatch
    ) -> tuple[np.ndarray, np.ndarray]:
        """Look up draft tokens and their tree mask from the lookahead cache.

        Returns:
            ``(req_drafts, mask)`` — flat draft-token ids and the flattened
            per-request tree mask produced by ``LookaheadCache.batch_get``.
        """
        bs = batch.batch_size()

        self.lookahead_cache.synchronize()
        batch_tokens = []
        for req in batch.reqs:
            # Cache lookup key: the last max_match_window_size tokens.
            check_token = self._efficient_concat_last_n(
                req.origin_input_ids, req.output_ids, self.max_match_window_size
            )
            batch_tokens.append(check_token)
        req_drafts, mask = self.lookahead_cache.batch_get(batch_tokens)
        total_draft_token_num = len(req_drafts)

        # Check if speculative decoding is needed; here we always enforce it
        assert (
            total_draft_token_num == bs * self.draft_token_num
        ), f"{total_draft_token_num=}, {bs=}, {self.draft_token_num=}"
        return req_drafts, mask

    def _prepare_for_speculative_decoding(self, batch: ScheduleBatch):
        """Attach a ``LookaheadVerifyInput`` to the batch and switch it to TARGET_VERIFY.

        Extend (prefill) batches are left untouched; speculation only applies
        to decode steps.
        """
        if batch.forward_mode.is_extend():
            return

        bs = batch.batch_size()

        # Preallocated views for this batch size.
        retrive_index = self.retrieve_indexes_batch[bs]
        retrive_next_token = self.retrive_next_token_batch[bs]
        retrive_next_sibling = self.retrive_next_sibling_batch[bs]
        positions = self.positions_batch[bs]
        tree_mask = self.tree_mask_batch[bs]
        draft_tokens = self.draft_tokens_batch[bs]

        req_drafts, mask = self._prepare_draft_tokens(batch)
        tree_mask.copy_(torch.from_numpy(mask), non_blocking=True)
        draft_tokens.copy_(torch.from_numpy(req_drafts), non_blocking=True)

        # Derive positions and parent/sibling links from the tree mask.
        reconstruct_indices_from_tree_mask(
            tree_mask,
            batch.seq_lens,
            positions,  # mutable
            retrive_index,  # mutable
            retrive_next_token,  # mutable
            retrive_next_sibling,  # mutable
            bs,
            self.draft_token_num,
        )

        # NOTE: QLEN_MASK is faster than FULL_MASK, but requires corresponding changes in flashinfer.
        # Testing shows about 8% performance improvement (the effect is roughly proportional to batch size).
        if USE_FULL_MASK:
            # Rebuild a full mask: all-ones over the existing prefix, followed
            # by the draft-token tree mask, per request.
            tree_mask = []
            mask = mask.reshape(
                batch.batch_size(), self.draft_token_num, self.draft_token_num
            )
            for i, req in enumerate(batch.reqs):
                seq_len = len(req.origin_input_ids) + len(req.output_ids)
                req_mask = torch.ones((self.draft_token_num, seq_len - 1)).cuda()
                req_mask = torch.cat(
                    (req_mask, torch.from_numpy(mask[i]).cuda()), dim=1
                ).to(torch.bool)
                tree_mask.append(req_mask.flatten())
            tree_mask = torch.cat(tree_mask, dim=0)

        batch.spec_algorithm = SpeculativeAlgorithm.LOOKAHEAD
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.spec_info = LookaheadVerifyInput(
            draft_tokens,
            tree_mask,
            positions,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            self.draft_token_num,
        )
        batch.spec_info.prepare_for_verify(batch, self.page_size)

    def _update_lookahead_cache(self, batch: ScheduleBatch):
        """Insert each request's most recent ``branch_length`` tokens into the cache."""
        batch_tokens = []
        for req in batch.reqs:
            # FIXME: Whether to insert 'extend' into the cache or not, after testing,
            # there is not much difference, so we will not insert it for now.
            # if batch.forward_mode.is_extend():
            #     put_ids = req.origin_input_ids + req.output_ids
            # else:
            put_ids = self._efficient_concat_last_n(
                req.origin_input_ids, req.output_ids, self.branch_length
            )
            batch_tokens.append(put_ids)
        self.lookahead_cache.batch_put(batch_tokens)

    def forward_batch_speculative_generation(self, batch: ScheduleBatch):
        """Run one speculative step: draft from the cache, verify with the target model.

        Returns:
            ``(logits_output, next_token_ids, bid, num_accepted_tokens,
            can_run_cuda_graph)``. ``num_accepted_tokens`` is 0 for
            non-verify (extend) steps.
        """
        self._prepare_for_speculative_decoding(batch)
        model_worker_batch = batch.get_model_worker_batch()
        bid = model_worker_batch.bid
        num_accepted_tokens = 0

        if model_worker_batch.forward_mode.is_target_verify():
            # Target forward without sampling; verification picks the tokens.
            logits_output, _, can_run_cuda_graph = (
                self.target_worker.forward_batch_generation(
                    model_worker_batch, skip_sample=True
                )
            )
            verify_input = model_worker_batch.spec_info
            logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
                batch, logits_output, self.page_size
            )
            self._update_lookahead_cache(batch)
            # Restore decode mode for the scheduler.
            batch.forward_mode = ForwardMode.DECODE

        else:
            logits_output, next_token_ids, can_run_cuda_graph = (
                self.target_worker.forward_batch_generation(model_worker_batch)
            )

        return (
            logits_output,
            next_token_ids,
            bid,
            num_accepted_tokens,
            can_run_cuda_graph,
        )
|
||||
@@ -6,6 +6,7 @@ class SpeculativeAlgorithm(IntEnum):
|
||||
EAGLE = auto()
|
||||
EAGLE3 = auto()
|
||||
STANDALONE = auto()
|
||||
LOOKAHEAD = auto()
|
||||
|
||||
def is_none(self):
|
||||
return self == SpeculativeAlgorithm.NONE
|
||||
@@ -19,12 +20,16 @@ class SpeculativeAlgorithm(IntEnum):
|
||||
def is_standalone(self):
|
||||
return self == SpeculativeAlgorithm.STANDALONE
|
||||
|
||||
def is_lookahead(self):
|
||||
return self == SpeculativeAlgorithm.LOOKAHEAD
|
||||
|
||||
@staticmethod
|
||||
def from_string(name: str):
|
||||
name_map = {
|
||||
"EAGLE": SpeculativeAlgorithm.EAGLE,
|
||||
"EAGLE3": SpeculativeAlgorithm.EAGLE3,
|
||||
"STANDALONE": SpeculativeAlgorithm.STANDALONE,
|
||||
"LOOKAHEAD": SpeculativeAlgorithm.LOOKAHEAD,
|
||||
None: SpeculativeAlgorithm.NONE,
|
||||
}
|
||||
if name is not None:
|
||||
|
||||
@@ -80,6 +80,7 @@ DEFAULT_STANDALONE_SPECULATIVE_TARGET_MODEL_FOR_TEST = (
|
||||
"meta-llama/Llama-3.1-8B-Instruct"
|
||||
)
|
||||
DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-Instruct"
|
||||
DEFAULT_LOOKAHEAD_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
|
||||
|
||||
# Other use cases
|
||||
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
|
||||
|
||||
Reference in New Issue
Block a user