sglang/python/sglang/srt/speculative/ngram_worker.py
import logging
from typing import List, Optional

import numpy as np
import torch
from sgl_kernel.speculative import reconstruct_indices_from_tree_mask

from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
from sglang.srt.speculative.ngram_info import NgramVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

logger = logging.getLogger(__name__)

USE_FULL_MASK = True

class NGRAMWorker:
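    """Speculative decoding worker that drafts tokens from an n-gram cache.

    For each decode step it looks up draft tokens for every request in an
    NgramCache, has the target model verify the draft tree in a single
    TARGET_VERIFY forward pass, and feeds the resulting tokens back into the
    cache. (Summary added for readability; behavior follows the code below.)
    """
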
    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        self.target_worker = target_worker
        self.model_runner = target_worker.model_runner
        self.tp_rank = tp_rank
        self.page_size = server_args.page_size
        self.draft_token_num: int = server_args.speculative_num_draft_tokens
        self.branch_length: int = server_args.speculative_ngram_branch_length
        self.max_match_window_size: int = (
            server_args.speculative_ngram_max_match_window_size
        )

        self.max_batch_size = target_worker.max_running_requests
        self.device = f"cuda:{gpu_id}" if gpu_id >= 0 else "cuda"
        self._init_preallocated_tensors()

        self.ngram_cache = NgramCache(
            min_match_window_size=server_args.speculative_ngram_min_match_window_size,
            max_match_window_size=server_args.speculative_ngram_max_match_window_size,
            min_bfs_breadth=server_args.speculative_ngram_min_bfs_breadth,
            max_bfs_breadth=server_args.speculative_ngram_max_bfs_breadth,
            capacity=server_args.speculative_ngram_capacity,
            branch_length=server_args.speculative_ngram_branch_length,
            draft_token_num=server_args.speculative_num_draft_tokens,
        )

    def clear_cache_pool(self):
        self.ngram_cache.reset()

    def _efficient_concat_last_n(self, seq1: List[int], seq2: List[int], n: int):
        seq2_len = len(seq2)
        if seq2_len >= n:
            return seq2[-n:]

        need_from_seq1 = n - seq2_len
        return seq1[-need_from_seq1:] + seq2
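    # Example for _efficient_concat_last_n with made-up values: for
    # seq1=[1, 2, 3, 4], seq2=[5, 6] and n=3, seq2 alone is too short, so one
    # token is taken from the end of seq1 and the result is [4, 5, 6], i.e. the
    # last n tokens of seq1 + seq2 without building the full concatenation.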
    def _init_preallocated_tensors(self):
        max_total_drafts = self.max_batch_size * self.draft_token_num
        max_total_mask_size = (
            self.max_batch_size * self.draft_token_num * self.draft_token_num
        )
        self.draft_tokens = torch.empty(
            (max_total_drafts,), dtype=torch.int64, device=self.device
        )
        self.retrieve_indexes = torch.empty(
            (self.max_batch_size, self.draft_token_num),
            dtype=torch.int64,
            device=self.device,
        )
        self.retrive_next_token = torch.empty(
            (self.max_batch_size, self.draft_token_num),
            dtype=torch.int64,
            device=self.device,
        )
        self.retrive_next_sibling = torch.empty(
            (self.max_batch_size, self.draft_token_num),
            dtype=torch.int64,
            device=self.device,
        )
        self.positions = torch.empty(
            (max_total_drafts,), dtype=torch.int64, device=self.device
        )
        self.tree_mask = torch.empty(
            (max_total_mask_size,), dtype=torch.bool, device=self.device
        )

        self.draft_tokens_batch = []
        self.tree_mask_batch = []
        self.retrieve_indexes_batch = []
        self.retrive_next_token_batch = []
        self.retrive_next_sibling_batch = []
        self.positions_batch = []
        for bs in range(0, self.max_batch_size + 1):
            self.retrieve_indexes_batch.append(self.retrieve_indexes[:bs, :])
            self.retrive_next_token_batch.append(self.retrive_next_token[:bs, :])
            self.retrive_next_sibling_batch.append(self.retrive_next_sibling[:bs, :])
            self.positions_batch.append(self.positions[: bs * self.draft_token_num])
            self.draft_tokens_batch.append(
                self.draft_tokens[: bs * self.draft_token_num]
            )
            self.tree_mask_batch.append(
                self.tree_mask[: bs * self.draft_token_num * self.draft_token_num]
            )
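    # The *_batch lists above hold pre-sliced views of the preallocated tensors,
    # indexed by batch size: e.g. self.draft_tokens_batch[bs] is
    # self.draft_tokens[: bs * self.draft_token_num]. Each decode step can then
    # pick the view matching its batch size without allocating new tensors.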
    def _prepare_draft_tokens(
        self, batch: ScheduleBatch
    ) -> tuple[np.ndarray, np.ndarray]:
        bs = batch.batch_size()

        self.ngram_cache.synchronize()
        batch_tokens = []
        for req in batch.reqs:
            check_token = self._efficient_concat_last_n(
                req.origin_input_ids, req.output_ids, self.max_match_window_size
            )
            batch_tokens.append(check_token)
        req_drafts, mask = self.ngram_cache.batch_get(batch_tokens)
        total_draft_token_num = len(req_drafts)

        # Check whether speculative decoding is needed; here we always enforce it.
        assert (
            total_draft_token_num == bs * self.draft_token_num
        ), f"{total_draft_token_num=}, {bs=}, {self.draft_token_num=}"
        return req_drafts, mask
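    # batch_get above returns, for the whole batch, a flat array of
    # bs * draft_token_num draft tokens and a flat tree mask of length
    # bs * draft_token_num * draft_token_num, which is reshaped to
    # (bs, draft_token_num, draft_token_num) during verify preparation below.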
    def _prepare_for_speculative_decoding(self, batch: ScheduleBatch):
        if batch.forward_mode.is_extend():
            return

        bs = batch.batch_size()

        retrive_index = self.retrieve_indexes_batch[bs]
        retrive_next_token = self.retrive_next_token_batch[bs]
        retrive_next_sibling = self.retrive_next_sibling_batch[bs]
        positions = self.positions_batch[bs]
        tree_mask = self.tree_mask_batch[bs]
        draft_tokens = self.draft_tokens_batch[bs]

        req_drafts, mask = self._prepare_draft_tokens(batch)
        tree_mask.copy_(torch.from_numpy(mask), non_blocking=True)
        draft_tokens.copy_(torch.from_numpy(req_drafts), non_blocking=True)

        reconstruct_indices_from_tree_mask(
            tree_mask,
            batch.seq_lens,
            positions,  # mutable
            retrive_index,  # mutable
            retrive_next_token,  # mutable
            retrive_next_sibling,  # mutable
            bs,
            self.draft_token_num,
        )
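        # Roughly speaking, the kernel derives the draft-tree layout from the
        # mask: positions holds each draft token's position in its sequence,
        # retrive_index maps tree slots to token indices, and retrive_next_token /
        # retrive_next_sibling encode child and sibling links of the tree. The
        # exact semantics follow the sgl_kernel implementation.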
        # NOTE: QLEN_MASK is faster than FULL_MASK, but requires corresponding
        # changes in flashinfer. Testing shows about an 8% performance improvement
        # (the effect is roughly proportional to batch size).
        if USE_FULL_MASK:
            tree_mask = []
            mask = mask.reshape(
                batch.batch_size(), self.draft_token_num, self.draft_token_num
            )
            for i, req in enumerate(batch.reqs):
                seq_len = len(req.origin_input_ids) + len(req.output_ids)
                req_mask = torch.ones((self.draft_token_num, seq_len - 1)).cuda()
                req_mask = torch.cat(
                    (req_mask, torch.from_numpy(mask[i]).cuda()), dim=1
                ).to(torch.bool)
                tree_mask.append(req_mask.flatten())
            tree_mask = torch.cat(tree_mask, dim=0)
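        # With USE_FULL_MASK, each request gets a
        # (draft_token_num, seq_len - 1 + draft_token_num) mask: all ones over the
        # already committed prefix, followed by that request's draft tree mask,
        # flattened and concatenated across the batch.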
        batch.spec_algorithm = SpeculativeAlgorithm.NGRAM
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.spec_info = NgramVerifyInput(
            draft_tokens,
            tree_mask,
            positions,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            self.draft_token_num,
        )
        batch.spec_info.prepare_for_verify(batch, self.page_size)

    def _update_ngram_cache(self, batch: ScheduleBatch):
        batch_tokens = []
        for req in batch.reqs:
            # FIXME: Whether to also insert 'extend' tokens into the cache is still
            # open; testing showed little difference, so we do not insert them for now.
            # if batch.forward_mode.is_extend():
            #     put_ids = req.origin_input_ids + req.output_ids
            # else:
            put_ids = self._efficient_concat_last_n(
                req.origin_input_ids, req.output_ids, self.branch_length
            )
            batch_tokens.append(put_ids)
        self.ngram_cache.batch_put(batch_tokens)

    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
        self._prepare_for_speculative_decoding(batch)
        model_worker_batch = batch.get_model_worker_batch()
        num_accepted_tokens = 0

        if model_worker_batch.forward_mode.is_target_verify():
            batch_result = self.target_worker.forward_batch_generation(
                model_worker_batch, is_verify=True
            )
            logits_output, can_run_cuda_graph = (
                batch_result.logits_output,
                batch_result.can_run_cuda_graph,
            )
            verify_input = model_worker_batch.spec_info
            logits_output, next_token_ids, num_accepted_tokens = verify_input.verify(
                batch, logits_output, self.page_size
            )
            self._update_ngram_cache(batch)
            batch.forward_mode = ForwardMode.DECODE
        else:
            batch_result = self.target_worker.forward_batch_generation(
                model_worker_batch
            )
            logits_output, next_token_ids, can_run_cuda_graph = (
                batch_result.logits_output,
                batch_result.next_token_ids,
                batch_result.can_run_cuda_graph,
            )

        return GenerationBatchResult(
            logits_output=logits_output,
            next_token_ids=next_token_ids,
            num_accepted_tokens=num_accepted_tokens,
            can_run_cuda_graph=can_run_cuda_graph,
        )