2025-02-03 20:52:30 +08:00
import logging
2025-03-03 10:58:45 +08:00
import os
2025-02-03 20:52:30 +08:00
import time
2025-03-16 02:48:55 -07:00
from contextlib import contextmanager
2025-03-06 06:13:59 -08:00
from typing import List , Optional , Tuple
2025-01-02 19:22:34 +08:00
import torch
2025-03-03 10:58:45 +08:00
from huggingface_hub import snapshot_download
2025-01-02 19:22:34 +08:00
2025-06-17 15:33:28 +08:00
from sglang . srt . distributed import (
GroupCoordinator ,
get_tp_group ,
patch_tensor_parallel_group ,
)
2025-01-02 19:22:34 +08:00
from sglang . srt . layers . logits_processor import LogitsProcessorOutput
2025-03-07 22:12:13 -08:00
from sglang . srt . layers . sampler import get_token_ids_logprobs , get_top_logprobs
2025-10-13 19:34:43 +08:00
from sglang . srt . managers . schedule_batch import ScheduleBatch
2025-10-07 20:12:12 +08:00
from sglang . srt . managers . scheduler import GenerationBatchResult
2025-01-02 19:22:34 +08:00
from sglang . srt . managers . tp_worker import TpModelWorker
2025-10-10 17:38:54 -07:00
from sglang . srt . mem_cache . common import (
alloc_paged_token_slots_extend ,
alloc_token_slots ,
get_last_loc ,
)
2025-01-02 19:22:34 +08:00
from sglang . srt . model_executor . forward_batch_info import (
CaptureHiddenMode ,
ForwardBatch ,
ForwardMode ,
)
2025-10-14 23:52:04 +08:00
from sglang . srt . server_args import ServerArgs
from sglang . srt . speculative . draft_utils import DraftBackendFactory
2025-02-03 20:52:30 +08:00
from sglang . srt . speculative . eagle_draft_cuda_graph_runner import (
EAGLEDraftCudaGraphRunner ,
)
2025-05-27 17:35:17 +08:00
from sglang . srt . speculative . eagle_draft_extend_cuda_graph_runner import (
EAGLEDraftExtendCudaGraphRunner ,
)
2025-10-01 09:45:30 +08:00
from sglang . srt . speculative . eagle_info import (
2025-02-03 20:52:30 +08:00
EagleDraftInput ,
EagleVerifyInput ,
2025-03-05 08:06:07 -08:00
EagleVerifyOutput ,
2025-10-01 09:45:30 +08:00
)
2025-10-14 16:50:53 +02:00
from sglang . srt . speculative . eagle_utils import (
build_tree_kernel_efficient ,
organize_draft_results ,
)
2025-10-01 09:45:30 +08:00
from sglang . srt . speculative . spec_info import SpeculativeAlgorithm
from sglang . srt . speculative . spec_utils import (
2025-02-03 20:52:30 +08:00
assign_draft_cache_locs ,
2025-06-16 03:04:29 -07:00
fast_topk ,
2025-05-22 08:18:41 +08:00
generate_token_bitmask ,
2025-02-03 20:52:30 +08:00
select_top_k_tokens ,
)
2025-06-16 03:04:29 -07:00
from sglang . srt . utils import (
empty_context ,
get_available_gpu_memory ,
2025-08-29 11:43:57 -07:00
get_bool_env_var ,
2025-09-13 17:18:26 +08:00
is_blackwell ,
2025-06-16 03:04:29 -07:00
is_cuda ,
next_power_of_2 ,
)
# `segment_packbits` lives in the CUDA-only `sgl_kernel` package, so the
# import is guarded to keep the module importable on non-CUDA platforms.
if is_cuda():
    from sgl_kernel import segment_packbits

logger = logging.getLogger(__name__)

# When set, logprobs are computed from the raw (untempered) logits instead of
# the temperature-scaled ones (see `add_logprob_values`).
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    """Temporarily install `tp_group` as the global tensor-parallel group.

    The draft model doesn't use dp and has its own tp group.
    We disable mscclpp now because it doesn't support 2 comm groups.
    """
    with patch_tensor_parallel_group(tp_group):
        yield
class EAGLEWorker ( TpModelWorker ) :
    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        """Create the draft-model worker used for EAGLE speculative decoding.

        The draft worker shares the request/KV-cache allocators with
        `target_worker`, loads the draft model via `super().__init__()`, wires
        the target model's embedding / lm_head into the draft model, and then
        initializes attention backends and cuda graphs.
        """
        # Parse arguments
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self.target_worker = target_worker
        self.page_size = server_args.page_size
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        # -1 sentinel: no static padding configured.
        self.padded_static_len = -1

        # Override the context length of the draft model to be the same as the target model.
        server_args.context_length = target_worker.model_runner.model_config.context_len

        # Do not capture cuda graph in `super().__init__()`
        # It will be captured later.
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True

        # Share the allocator with a target worker.
        # Draft and target worker own their own KV cache pools.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Load hot token ids
        if self.speculative_algorithm.is_eagle3():
            if server_args.speculative_token_map is not None:
                logger.warning(
                    "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
                )
            self.hot_token_id = None
        elif server_args.speculative_token_map is not None:
            # NOTE(review): `load_token_map` is not among the imports visible in
            # this chunk — confirm it is imported elsewhere in the file.
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )
        else:
            self.hot_token_id = None

        # Init draft worker
        with empty_context():
            super().__init__(
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                moe_ep_rank=moe_ep_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        embed, head = self.target_worker.model_runner.model.get_embed_and_head()

        if self.speculative_algorithm.is_eagle3():
            # most cases EAGLE3 models don't share lm_head
            # but some models (e.g. nvidia/gpt-oss-120b-Eagle3) shares
            if (
                hasattr(self.draft_model_runner.model, "load_lm_head_from_target")
                and self.draft_model_runner.model.load_lm_head_from_target
            ):
                self.draft_model_runner.model.set_embed_and_head(embed, head)
            else:
                self.draft_model_runner.model.set_embed(embed)

            # grab hot token ids
            if self.draft_model_runner.model.hot_token_id is not None:
                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
                    embed.device
                )
        else:
            if self.hot_token_id is not None:
                # Clone before slicing so the target model's head is untouched.
                head = head.clone()
                self.hot_token_id = self.hot_token_id.to(head.device)
                head.data = head.data[self.hot_token_id]

            # Share the embedding and lm_head
            self.draft_model_runner.model.set_embed_and_head(embed, head)

        # Init attention backend and cuda graphs
        self.draft_model_runner.server_args.disable_cuda_graph = (
            backup_disable_cuda_graph
        )
        self.draft_tp_context = (
            draft_tp_context if server_args.enable_dp_attention else empty_context
        )
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

        # Some dummy tensors
        self.num_new_pages_per_topk = torch.empty(
            (), dtype=torch.int64, device=self.device
        )
        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)
def init_attention_backend ( self ) :
2025-05-12 12:53:26 -07:00
# Create multi-step attn backends and cuda graph runners
2025-06-16 15:57:07 -07:00
self . has_prefill_wrapper_verify = False
self . draft_extend_attn_backend = None
2025-10-14 23:52:04 +08:00
draft_backend_factory = DraftBackendFactory (
self . server_args ,
self . draft_model_runner ,
self . topk ,
self . speculative_num_steps ,
2025-09-03 13:47:23 +08:00
)
2025-10-14 23:52:04 +08:00
# Initialize decode attention backend
self . draft_attn_backend = draft_backend_factory . create_decode_backend ( )
2025-09-03 13:47:23 +08:00
2025-10-14 23:52:04 +08:00
# Initialize draft extend attention backend (respects speculative_attention_mode setting)
self . draft_extend_attn_backend = (
draft_backend_factory . create_draft_extend_backend ( )
2025-09-03 13:47:23 +08:00
)
2025-10-14 23:52:04 +08:00
self . draft_model_runner . draft_attn_backend = self . draft_attn_backend
2025-10-04 20:59:34 +08:00
2025-02-03 20:52:30 +08:00
def init_cuda_graphs ( self ) :
2025-05-12 12:53:26 -07:00
""" Capture cuda graphs. """
2025-02-03 20:52:30 +08:00
self . cuda_graph_runner = None
2025-03-16 02:48:55 -07:00
self . cuda_graph_runner_for_draft_extend = None
2025-02-03 20:52:30 +08:00
if self . server_args . disable_cuda_graph :
return
2025-03-16 02:48:55 -07:00
# Capture draft
2025-05-17 16:49:18 -07:00
tic = time . perf_counter ( )
2025-03-16 02:48:55 -07:00
before_mem = get_available_gpu_memory ( self . device , self . gpu_id )
2025-03-05 08:06:07 -08:00
logger . info (
2025-05-12 12:53:26 -07:00
f " Capture draft cuda graph begin. This can take up to several minutes. avail mem= { before_mem : .2f } GB "
2025-03-05 08:06:07 -08:00
)
2025-02-03 20:52:30 +08:00
self . cuda_graph_runner = EAGLEDraftCudaGraphRunner ( self )
2025-03-16 02:48:55 -07:00
after_mem = get_available_gpu_memory ( self . device , self . gpu_id )
2025-03-05 08:06:07 -08:00
logger . info (
2025-06-16 03:04:29 -07:00
f " Capture draft cuda graph end. Time elapsed: { time . perf_counter ( ) - tic : .2f } s. mem usage= { ( before_mem - after_mem ) : .2f } GB. avail mem= { after_mem : .2f } GB. "
2025-03-05 08:06:07 -08:00
)
2025-01-02 19:22:34 +08:00
2025-03-16 02:48:55 -07:00
# Capture extend
if self . draft_extend_attn_backend :
2025-05-27 17:35:17 +08:00
tic = time . perf_counter ( )
before_mem = get_available_gpu_memory ( self . device , self . gpu_id )
logger . info (
f " Capture draft extend cuda graph begin. This can take up to several minutes. avail mem= { before_mem : .2f } GB "
)
self . cuda_graph_runner_for_draft_extend = EAGLEDraftExtendCudaGraphRunner (
self
)
after_mem = get_available_gpu_memory ( self . device , self . gpu_id )
logger . info (
2025-06-16 03:04:29 -07:00
f " Capture draft extend cuda graph end. Time elapsed: { time . perf_counter ( ) - tic : .2f } s. mem usage= { ( before_mem - after_mem ) : .2f } GB. avail mem= { after_mem : .2f } GB. "
2025-05-27 17:35:17 +08:00
)
2025-03-16 02:48:55 -07:00
2025-03-05 08:06:07 -08:00
    @property
    def draft_model_runner(self):
        # Readability alias: `self.model_runner` (set up in
        # `super().__init__()`) is the *draft* model's runner on this worker.
        return self.model_runner
    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
        """Run speculative decoding forward.

        NOTE: Many states of batch is modified as you go through. It is not guaranteed that
        the final output batch have the same state as the input.

        Args:
            batch: The batch to run forward. The state of the batch is modified as it runs.
        Returns:
            A GenerationBatchResult carrying the final logit output of the target
            model, the next tokens accepted, the number of accepted tokens, and
            whether cuda graph was used.
        """
        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
            # Prefill path: run the target model once to get full hidden
            # states, then prefill the draft model's KV cache with them.
            logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
                batch
            )
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                self.forward_draft_extend(
                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                )
            return GenerationBatchResult(
                logits_output=logits_output,
                next_token_ids=next_token_ids,
                num_accepted_tokens=0,
                can_run_cuda_graph=False,
            )
        else:
            # Decode path: draft -> verify -> draft-extend-after-decode.
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                spec_info = self.draft(batch)
            logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                self.verify(batch, spec_info)
            )

            with self.draft_tp_context(self.draft_model_runner.tp_group):
                # NOTE: We should use `check_forward_draft_extend_after_decode`
                # when DP attention is enabled, but it is slow. Skip it for now.
                if (
                    self.server_args.enable_dp_attention
                    or batch.spec_info.verified_id.shape[0] > 0
                ):
                    # decode is not finished
                    self.forward_draft_extend_after_decode(batch)

            return GenerationBatchResult(
                logits_output=logits_output,
                next_token_ids=verify_output.verified_id,
                num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
                can_run_cuda_graph=can_run_cuda_graph,
            )
def check_forward_draft_extend_after_decode ( self , batch : ScheduleBatch ) :
2025-07-24 21:36:21 -07:00
local_need_forward = batch . spec_info . verified_id . shape [ 0 ] > 0
2025-06-17 15:33:28 +08:00
if not self . server_args . enable_dp_attention :
2025-06-24 08:34:13 +08:00
return local_need_forward
2025-06-17 15:33:28 +08:00
global_need_forward = torch . tensor (
[
( local_need_forward ) ,
] ,
dtype = torch . int64 ,
)
torch . distributed . all_reduce (
global_need_forward , group = get_tp_group ( ) . cpu_group
)
global_need_forward_cnt = global_need_forward [ 0 ] . item ( )
need_forward = global_need_forward_cnt > 0
2025-06-24 08:34:13 +08:00
return need_forward
2025-03-05 08:06:07 -08:00
def forward_target_extend (
self , batch : ScheduleBatch
2025-07-24 21:36:21 -07:00
) - > Tuple [ LogitsProcessorOutput , torch . Tensor , int , Optional [ torch . Tensor ] ] :
2025-03-05 08:06:07 -08:00
""" Run the target extend.
Args :
batch : The batch to run . States could be modified .
Returns :
logits_output : The output of logits . It will contain the full hidden states .
next_token_ids : Next token ids generated .
"""
# Forward with the target model and get hidden states.
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch . get_model_worker_batch ( )
model_worker_batch . capture_hidden_mode = CaptureHiddenMode . FULL
2025-10-07 20:12:12 +08:00
batch_result = self . target_worker . forward_batch_generation ( model_worker_batch )
2025-10-03 00:28:57 +08:00
logits_output , next_token_ids = (
2025-10-07 20:12:12 +08:00
batch_result . logits_output ,
batch_result . next_token_ids ,
2025-10-03 00:28:57 +08:00
)
2025-06-09 16:39:21 -07:00
return (
logits_output ,
next_token_ids ,
model_worker_batch . seq_lens_cpu ,
)
2025-01-02 19:22:34 +08:00
2025-06-17 15:33:28 +08:00
    def _draft_preprocess_decode(self, batch: ScheduleBatch):
        """Prepare a decode batch for the multi-step draft forward.

        Allocates KV-cache slots for `speculative_num_steps * topk` draft
        tokens per sequence, writes their locations into the req-to-token
        table, and sets the per-branch positions. The allocator state is backed
        up before the allocation and restored at the end, so the draft slots
        are not permanently reserved.
        """
        # Parse args
        num_seqs = batch.batch_size()
        spec_info = batch.spec_info

        # Accumulate penalty
        if batch.sampling_info.penalizer_orchestrator.is_required:
            # This is a relaxed version of penalties for speculative decoding.
            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                spec_info.verified_id.to(torch.int64)
            )

        # Allocate cache locations
        # Layout of the out_cache_loc
        #   [      topk 0          ] [      topk 1          ]
        #   [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
        if self.page_size == 1:
            out_cache_loc, token_to_kv_pool_state_backup = alloc_token_slots(
                batch.tree_cache,
                num_seqs * self.speculative_num_steps * self.topk,
                backup_state=True,
            )
        else:
            if self.topk == 1:
                # NOTE(review): `get_last_loc_large_page_size_top_k_1` and
                # `get_last_loc_large_page_size_large_top_k` (below) are not
                # among the imports visible in this chunk — confirm they are
                # imported elsewhere in the file.
                prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
                    batch.req_to_token_pool.req_to_token,
                    batch.req_pool_indices,
                    batch.seq_lens,
                    self.speculative_num_steps,
                )
                prefix_lens_cpu = batch.seq_lens_cpu
                seq_lens_cpu = batch.seq_lens_cpu + self.speculative_num_steps
                extend_num_tokens = num_seqs * self.speculative_num_steps
            else:
                # In this case, the last partial page needs to be duplicated.
                # KV cache layout in batch.req_to_token_pool.req_to_token:
                #
                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
                #    prefix     top-k = 0    tok-k = 1    top-k = 2
                #
                # "-" means prefix tokens
                # "x" means speculative draft tokens
                # "." means padded tokens

                # TODO(lmzheng): The current implementation is still a fake support
                # for page size > 1. In the `assign_draft_cache_locs` below,
                # we directly move the indices instead of the real kv cache.
                # This only works when the kernel backend runs with page size = 1.
                # If the kernel backend runs with page size > 1, we need to
                # duplicate the real KV cache. The overhead of duplicating KV
                # cache seems okay because the draft KV cache only has one layer.
                # see a related copy operation in MHATokenToKVPool::move_kv_cache.
                (
                    prefix_lens,
                    seq_lens,
                    last_loc,
                    self.num_new_pages_per_topk,
                    self.extend_lens,
                ) = get_last_loc_large_page_size_large_top_k(
                    batch.req_to_token_pool.req_to_token,
                    batch.req_pool_indices,
                    batch.seq_lens,
                    self.speculative_num_steps,
                    self.topk,
                    self.page_size,
                )

                # CPU-side mirror of the padded (page-aligned) lengths.
                prefix_lens_cpu = batch.seq_lens_cpu
                last_page_lens = prefix_lens_cpu % self.page_size
                num_new_pages_per_topk = (
                    last_page_lens + self.speculative_num_steps + self.page_size - 1
                ) // self.page_size
                seq_lens_cpu = (
                    prefix_lens_cpu // self.page_size * self.page_size
                    + num_new_pages_per_topk * (self.page_size * self.topk)
                )
                extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()

            out_cache_loc, token_to_kv_pool_state_backup = (
                alloc_paged_token_slots_extend(
                    batch.tree_cache,
                    prefix_lens,
                    prefix_lens_cpu,
                    seq_lens,
                    seq_lens_cpu,
                    last_loc,
                    extend_num_tokens,
                    backup_state=True,
                )
            )

        # Write the allocated slot indices into the req-to-token table
        # (triton kernel launch over one program per sequence).
        assign_draft_cache_locs[(num_seqs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            self.extend_lens,
            self.num_new_pages_per_topk,
            out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            self.topk,
            self.speculative_num_steps,
            self.page_size,
            next_power_of_2(num_seqs),
            next_power_of_2(self.speculative_num_steps),
        )

        if self.page_size > 1 and self.topk > 1:
            # Remove padded slots
            out_cache_loc = out_cache_loc[
                : num_seqs * self.topk * self.speculative_num_steps
            ]

        batch.out_cache_loc = out_cache_loc
        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
        batch.return_hidden_states = False
        # Each of the topk draft branches starts at the current sequence length.
        spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)

        # Undo the temporary draft allocation in the allocator.
        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
def _draft_preprocess_idle ( self , batch : ScheduleBatch ) :
batch . spec_info = EagleDraftInput . create_idle_input (
device = self . device ,
hidden_size = self . model_config . hidden_size ,
2025-06-23 11:23:25 -07:00
dtype = self . model_config . dtype ,
2025-06-17 15:33:28 +08:00
topk = self . topk ,
capture_hidden_mode = CaptureHiddenMode . LAST ,
)
    def draft(self, batch: ScheduleBatch):
        """Run the multi-step draft forward and build the draft token tree.

        Returns:
            An `EagleVerifyInput` describing the draft token tree for the
            target model to verify (an idle variant for idle batches).
        """
        # Parse args
        if batch.forward_mode.is_idle():
            self._draft_preprocess_idle(batch)
        else:
            self._draft_preprocess_decode(batch)

        spec_info = batch.spec_info
        assert isinstance(spec_info, EagleDraftInput)

        # Only the last hidden state is needed for the next draft step.
        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        spec_info.num_tokens_per_batch = self.topk
        spec_info.num_tokens_for_logprob_per_batch = self.topk
        batch.return_hidden_states = False

        # Get forward batch
        model_worker_batch = batch.get_model_worker_batch()
        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
            forward_batch
        )
        if can_cuda_graph:
            # Replay the captured multi-step draft graph.
            parent_list, top_scores_index, draft_tokens = self.cuda_graph_runner.replay(
                forward_batch
            )
        else:
            forward_batch.can_run_dp_cuda_graph = False
            if not forward_batch.forward_mode.is_idle():
                # Initialize attention backend
                self.draft_attn_backend.init_forward_metadata(forward_batch)
            # Run forward steps
            parent_list, top_scores_index, draft_tokens = self.draft_forward(
                forward_batch
            )

        if batch.forward_mode.is_idle():
            return EagleVerifyInput.create_idle_input(
                self.topk,
                self.speculative_num_steps,
                self.speculative_num_draft_tokens,
            )

        # Build the tree mask and retrieval indices used by target verification.
        (
            tree_mask,
            position,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            draft_tokens,
        ) = build_tree_kernel_efficient(
            spec_info.verified_id,
            parent_list,
            top_scores_index,
            draft_tokens,
            batch.seq_lens,
            batch.seq_lens_sum,
            self.topk,
            self.speculative_num_steps,
            self.speculative_num_draft_tokens,
        )

        return EagleVerifyInput(
            draft_token=draft_tokens,
            custom_mask=tree_mask,
            positions=position,
            retrive_index=retrive_index,
            retrive_next_token=retrive_next_token,
            retrive_next_sibling=retrive_next_sibling,
            retrive_cum_len=None,
            spec_steps=self.speculative_num_steps,
            topk=self.topk,
            draft_token_num=self.server_args.speculative_num_draft_tokens,
            capture_hidden_mode=CaptureHiddenMode.FULL,
            seq_lens_sum=forward_batch.seq_lens_sum,
            seq_lens_cpu=forward_batch.seq_lens_cpu,
        )
    def draft_forward(self, forward_batch: ForwardBatch):
        """Eagerly run the multi-step draft forward (non-cuda-graph path).

        Returns:
            The (parent_list, top_scores_index, draft_tokens) triple produced
            by `organize_draft_results`, used to build the verification tree.
        """
        # Parse args
        spec_info = forward_batch.spec_info
        assert isinstance(spec_info, EagleDraftInput)
        out_cache_loc = forward_batch.out_cache_loc
        topk_p, topk_index, hidden_states = (
            spec_info.topk_p,
            spec_info.topk_index,
            spec_info.hidden_states,
        )
        # Map hot-vocab indices back to the full vocabulary.
        if self.hot_token_id is not None:
            topk_index = self.hot_token_id[topk_index]

        # Reshape so out_cache_loc[i] yields the cache slots of step i across
        # all (batch, topk) draft branches (matches the layout documented in
        # `_draft_preprocess_decode`).
        out_cache_loc = out_cache_loc.reshape(
            forward_batch.batch_size, self.topk, self.speculative_num_steps
        )
        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
            self.speculative_num_steps, -1
        )

        # Return values
        score_list: List[torch.Tensor] = []
        token_list: List[torch.Tensor] = []
        parents_list: List[torch.Tensor] = []

        # Forward multiple steps
        scores = None
        for i in range(self.speculative_num_steps):
            input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
                i, topk_p, topk_index, hidden_states, scores, self.topk
            )
            score_list.append(tree_info[0])
            token_list.append(tree_info[1])
            parents_list.append(tree_info[2])

            # We don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
            if i == self.speculative_num_steps - 1:
                break

            # Set inputs
            forward_batch.input_ids = input_ids

            # This is a temporary fix for the case that the user is using standalone
            # speculative decoding and the draft model architecture is gpt-oss. gpt-oss
            # rope kernel needs cache_loc to be contiguous.
            if (
                self.server_args.speculative_algorithm == "STANDALONE"
                and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
            ):
                out_cache_loc = out_cache_loc.contiguous()

            forward_batch.out_cache_loc = out_cache_loc[i]
            forward_batch.positions.add_(1)
            # Each draft step has its own pre-built attention backend.
            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
            spec_info.hidden_states = hidden_states

            # Run forward
            logits_output, _ = self.draft_model_runner.forward(
                forward_batch, skip_attn_backend_init=True
            )
            self._detect_nan_if_needed(logits_output)
            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
            if self.hot_token_id is not None:
                topk_index = self.hot_token_id[topk_index]
            hidden_states = logits_output.hidden_states

        parent_list, top_scores_index, draft_tokens = organize_draft_results(
            score_list, token_list, parents_list, self.speculative_num_draft_tokens
        )
        return parent_list, top_scores_index, draft_tokens
def clear_cache_pool ( self ) :
self . model_runner . req_to_token_pool . clear ( )
self . model_runner . token_to_kv_pool_allocator . clear ( )
2025-02-03 20:52:30 +08:00
    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
        """Verify the draft token tree with one target-model forward.

        Accepts/rejects the draft tokens, filters the logits to the accepted
        positions, and prepares `batch` for the next draft round.

        Returns:
            (logits_output, verify_output, model_worker_batch,
             can_run_cuda_graph)
        """
        spec_info.prepare_for_verify(batch, self.page_size)
        batch.return_hidden_states = False
        batch.forward_mode = (
            ForwardMode.TARGET_VERIFY
            if not batch.forward_mode.is_idle()
            else ForwardMode.IDLE
        )
        batch.spec_info = spec_info

        model_worker_batch = batch.get_model_worker_batch(
            seq_lens_cpu_cache=spec_info.seq_lens_cpu
        )
        assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode

        # Copy the retrieval structures to CPU before the forward pass so that
        # the bitmask generation below can overlap with the GPU forward.
        if batch.has_grammar:
            retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
            retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
            draft_tokens_cpu = spec_info.draft_token.view(
                spec_info.retrive_next_token.shape
            ).cpu()

        # Forward
        batch_result = self.target_worker.forward_batch_generation(
            model_worker_batch, is_verify=True
        )
        logits_output, can_run_cuda_graph = (
            batch_result.logits_output,
            batch_result.can_run_cuda_graph,
        )

        vocab_mask = None
        if batch.has_grammar:
            # Generate the logit mask for structured output.
            # Overlap the CPU operations for bitmask generation with the forward pass.
            vocab_mask = generate_token_bitmask(
                batch.reqs,
                spec_info,
                retrieve_next_token_cpu,
                retrieve_next_sibling_cpu,
                draft_tokens_cpu,
                batch.sampling_info.vocab_size,
            )

            if vocab_mask is not None:
                assert spec_info.grammar is not None
                vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
                # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
                # and will be applied to produce wrong results
                batch.sampling_info.vocab_mask = None

        self._detect_nan_if_needed(logits_output)
        spec_info.hidden_states = logits_output.hidden_states
        res: EagleVerifyOutput = spec_info.verify(
            batch,
            logits_output,
            self.token_to_kv_pool_allocator,
            self.page_size,
            vocab_mask,
        )

        # Post process based on verified outputs.
        # Pick indices that we care (accepted)
        logits_output.next_token_logits = logits_output.next_token_logits[
            res.accepted_indices
        ]
        logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

        # QQ: can be optimized
        if self.target_worker.model_runner.hybrid_gdn_config is not None:
            # res.draft_input.accept_length is on GPU but may be empty for last verify?
            accepted_length = (
                torch.tensor(
                    res.accept_length_per_req_cpu,
                    device=logits_output.hidden_states.device,
                    dtype=torch.int32,
                )
                + 1
            )
            self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
                accepted_length, self.target_worker.model_runner.model
            )

        if batch.return_logprob:
            self.add_logprob_values(batch, res, logits_output)

        # Prepare the batch for the next draft forwards.
        batch.forward_mode = (
            ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
        )
        batch.spec_info = res.draft_input

        return logits_output, res, model_worker_batch, can_run_cuda_graph
def add_logprob_values(
    self,
    batch: ScheduleBatch,
    res: EagleVerifyOutput,
    logits_output: LogitsProcessorOutput,
) -> None:
    """Compute and attach logprob outputs for the accepted (verified) tokens.

    Mutates ``logits_output`` in place (``next_token_logprobs`` and, when
    requested, the top-/token-id-logprob fields) and appends per-token
    logprob entries to each request in ``batch.reqs``.

    Args:
        batch: The schedule batch whose requests receive the logprob values.
        res: Verification output holding accepted indices, verified ids, and
            per-request accept lengths.
        logits_output: NOTE(review): this parameter is immediately replaced by
            ``res.logits_output`` below, so the argument itself is unused —
            confirm whether callers rely on passing it.
    """
    # Extract args
    logits_output = res.logits_output
    top_logprobs_nums = batch.top_logprobs_nums
    token_ids_logprobs = batch.token_ids_logprobs
    accepted_indices = res.accepted_indices
    # One logit row per accepted token — sanity-check the flattened shapes agree.
    assert len(accepted_indices) == len(logits_output.next_token_logits)

    temperatures = batch.sampling_info.temperatures
    num_draft_tokens = batch.spec_info.draft_token_num
    # acceptance indices are the indices in a "flattened" batch.
    # dividing it to num_draft_tokens will yield the actual batch index.
    temperatures = temperatures[accepted_indices // num_draft_tokens]
    # RETURN_ORIGINAL_LOGPROB is a module-level flag (defined elsewhere in this
    # file): when set, report logprobs of the raw logits instead of the
    # temperature-scaled distribution actually used for sampling.
    if RETURN_ORIGINAL_LOGPROB:
        logprobs = torch.nn.functional.log_softmax(
            logits_output.next_token_logits, dim=-1
        )
    else:
        logprobs = torch.nn.functional.log_softmax(
            logits_output.next_token_logits / temperatures, dim=-1
        )
    batch_next_token_ids = res.verified_id
    # Each request contributes its accepted draft tokens plus one bonus token.
    num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]
    # We should repeat top_logprobs_nums to match num_tokens_per_req.
    top_logprobs_nums_repeat_interleaved = []
    token_ids_logprobs_repeat_interleaved = []
    for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
        top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
    for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
        token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
    # Extract logprobs
    if any(x > 0 for x in top_logprobs_nums):
        (
            logits_output.next_token_top_logprobs_val,
            logits_output.next_token_top_logprobs_idx,
        ) = get_top_logprobs(
            logprobs,
            top_logprobs_nums_repeat_interleaved,
        )
    if any(x is not None for x in token_ids_logprobs):
        (
            logits_output.next_token_token_ids_logprobs_val,
            logits_output.next_token_token_ids_logprobs_idx,
        ) = get_token_ids_logprobs(
            logprobs,
            token_ids_logprobs_repeat_interleaved,
        )
    # Gather the logprob of each verified token id from its own row.
    logits_output.next_token_logprobs = logprobs[
        torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
        batch_next_token_ids,
    ]

    # Add output logprobs to the request
    pt = 0  # cursor into the flattened per-token arrays
    next_token_logprobs = logits_output.next_token_logprobs.tolist()
    verified_ids = batch_next_token_ids.tolist()
    for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
        for _ in range(num_tokens):
            if req.return_logprob:
                req.output_token_logprobs_val.append(next_token_logprobs[pt])
                req.output_token_logprobs_idx.append(verified_ids[pt])
            if req.top_logprobs_num > 0:
                req.output_top_logprobs_val.append(
                    res.logits_output.next_token_top_logprobs_val[pt]
                )
                req.output_top_logprobs_idx.append(
                    res.logits_output.next_token_top_logprobs_idx[pt]
                )
            pt += 1
2025-03-05 08:06:07 -08:00
def forward_draft_extend(
    self,
    batch: ScheduleBatch,
    hidden_states: torch.Tensor,
    next_token_ids: torch.Tensor,
    seq_lens_cpu: Optional[torch.Tensor],
):
    """Run draft model extend. This API modifies the states of the batch.

    Args:
        batch: The batch to run.
        hidden_states: Hidden states from the target model forward.
        next_token_ids: Next token ids generated from the target forward.
        seq_lens_cpu: Optional CPU copy of sequence lengths, cached onto the
            model worker batch.
    """
    draft_input = EagleDraftInput(
        hidden_states=hidden_states,
        verified_id=next_token_ids,
        num_tokens_per_batch=1,
        num_tokens_for_logprob_per_batch=1,
    )
    batch.spec_info = draft_input
    batch.return_hidden_states = False
    draft_input.prepare_for_extend(batch)
    draft_input.capture_hidden_mode = CaptureHiddenMode.LAST

    worker_batch = batch.get_model_worker_batch(seq_lens_cpu_cache=seq_lens_cpu)
    fwd_batch = ForwardBatch.init_new(worker_batch, self.draft_model_runner)
    fwd_batch.return_logprob = False
    logits_output, _ = self.draft_model_runner.forward(fwd_batch)

    self._detect_nan_if_needed(logits_output)
    assert isinstance(fwd_batch.spec_info, EagleDraftInput)
    assert fwd_batch.spec_info is batch.spec_info
    self.capture_for_decode(logits_output, fwd_batch.spec_info)

    # Drop finished requests from the speculative state so the next draft
    # decode only operates on still-running requests.
    keep_indices = [i for i, req in enumerate(batch.reqs) if not req.finished()]
    if len(keep_indices) != len(batch.reqs):
        keep_indices_device = torch.tensor(
            keep_indices,
            dtype=torch.int64,
            device=batch.spec_info.topk_p.device,
        )
        batch.spec_info.filter_batch(keep_indices_device, has_been_filtered=False)
2025-01-02 19:22:34 +08:00
2025-06-24 08:34:13 +08:00
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
    """Run a draft-model extend pass right after target-model decode/verify.

    Feeds the tokens accepted by verification through the draft model so the
    draft state (top-k probs/indices, hidden states) is ready for the next
    round of speculative decoding. Several batch fields are modified in place
    for the draft forward and restored before returning.
    """
    assert isinstance(batch.spec_info, EagleDraftInput)
    # Backup fields that will be modified in-place
    seq_lens_backup = batch.seq_lens.clone()
    seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
    req_pool_indices_backup = batch.req_pool_indices
    accept_length_backup = batch.spec_info.accept_length
    return_logprob_backup = batch.return_logprob

    input_is_idle = batch.forward_mode.is_idle()
    # A non-idle batch with no verified tokens left cannot run a real extend;
    # convert a copy of it into an idle batch with placeholder draft input.
    if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
        batch = batch.copy()
        batch.prepare_for_idle()
        # EAGLE3 draft models consume 3 concatenated target hidden states.
        hidden_size = (
            self.model_config.hidden_size * 3
            if self.speculative_algorithm.is_eagle3()
            else self.model_config.hidden_size
        )
        batch.spec_info = EagleDraftInput.create_idle_input(
            device=self.device,
            hidden_size=hidden_size,
            dtype=self.model_config.dtype,
            topk=self.topk,
            capture_hidden_mode=CaptureHiddenMode.LAST,
        )
    batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
    batch.spec_info.num_tokens_for_logprob_per_batch = 1
    batch.spec_info.prepare_extend_after_decode(
        batch,
        self.speculative_num_steps,
    )
    batch.forward_mode = (
        ForwardMode.DRAFT_EXTEND
        if not batch.forward_mode.is_idle()
        else ForwardMode.IDLE
    )
    batch.return_hidden_states = False
    model_worker_batch = batch.get_model_worker_batch()
    assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
    forward_batch = ForwardBatch.init_new(
        model_worker_batch, self.draft_model_runner
    )
    # Recompute seq_lens_sum: prepare_extend_after_decode changed seq lens,
    # preferring the CPU copy to avoid a GPU sync when it is available.
    if forward_batch.seq_lens_cpu is not None:
        forward_batch.seq_lens_sum = forward_batch.seq_lens_cpu.sum().item()
    else:
        forward_batch.seq_lens_sum = batch.seq_lens.sum().item()
    # Run
    can_cuda_graph = (
        self.cuda_graph_runner_for_draft_extend
        and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
    )
    if can_cuda_graph:
        # Replay path: the graph runner already captured top-k extraction,
        # so copy its outputs straight into the draft input.
        logits_output = self.cuda_graph_runner_for_draft_extend.replay(
            forward_batch
        )
        forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = (
            logits_output.topk_p,
            logits_output.topk_index,
        )
        forward_batch.spec_info.hidden_states = logits_output.hidden_states
    else:
        forward_batch.can_run_dp_cuda_graph = False
        if not forward_batch.forward_mode.is_idle():
            self.draft_model_runner.attn_backend.init_forward_metadata(
                forward_batch
            )
        # skip_attn_backend_init=True: metadata was initialized just above
        # (or is unneeded for an idle batch).
        logits_output, _ = self.draft_model_runner.forward(
            forward_batch, skip_attn_backend_init=True
        )
        self.capture_for_decode(logits_output, forward_batch.spec_info)

    self._detect_nan_if_needed(logits_output)

    # Restore backup.
    # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
    batch.forward_mode = (
        ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
    )
    batch.seq_lens = seq_lens_backup
    batch.seq_lens_cpu = seq_lens_cpu_backup
    batch.req_pool_indices = req_pool_indices_backup
    batch.spec_info.accept_length = accept_length_backup
    batch.return_logprob = return_logprob_backup
2025-01-20 20:25:13 -08:00
2025-01-06 14:54:18 -08:00
def capture_for_decode(
    self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
):
    """Capture the draft sampling state for the next decode step.

    Converts the draft logits into top-k probabilities/indices and stashes
    them, along with the hidden states, on ``draft_input``.
    """
    next_token_probs = torch.softmax(logits_output.next_token_logits, dim=-1)
    topk_p, topk_index = fast_topk(next_token_probs, self.topk, dim=-1)
    draft_input.topk_p = topk_p
    draft_input.topk_index = topk_index
    draft_input.hidden_states = logits_output.hidden_states
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
    """Raise ValueError if NaNs appear in the logits (when detection is on)."""
    if not self.enable_nan_detection:
        return
    next_logits = logits_output.next_token_logits
    if torch.isnan(next_logits).any():
        message = "Detected errors during sampling! NaN in the logits."
        logger.error(message)
        raise ValueError(message)
2025-03-04 13:40:40 -08:00
def load_token_map(token_map_path: str) -> torch.Tensor:
    """Load a hot-token map and return it as an int64 tensor.

    Fix: the previous return annotation was ``List[int]``, but the function
    has always returned a ``torch.Tensor``; the annotation now matches.

    If ``token_map_path`` does not exist locally, its dirname is treated as a
    HuggingFace Hub repo id: the repo is downloaded (skipping weight files)
    and the basename is resolved inside the downloaded snapshot.

    Args:
        token_map_path: Local file path, or ``<hub-repo-id>/<filename>``.

    Returns:
        1-D ``torch.int64`` tensor of hot token ids.
    """
    if not os.path.exists(token_map_path):
        cache_dir = snapshot_download(
            os.path.dirname(token_map_path),
            ignore_patterns=["*.bin", "*.safetensors"],
        )
        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
    # weights_only=True restricts torch.load to plain data (no pickle code).
    hot_token_id = torch.load(token_map_path, weights_only=True)
    return torch.tensor(hot_token_id, dtype=torch.int64)
2025-06-16 03:04:29 -07:00
@torch.compile(dynamic=True)
def get_last_loc_large_page_size_top_k_1(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,  # annotation added for consistency with the sibling helper
    speculative_num_steps: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute allocation geometry for speculative extend with top-k == 1.

    Returns (prefix_lens, new seq_lens, last_loc): the current lengths, the
    lengths after reserving ``speculative_num_steps`` draft tokens, and the
    last KV-cache location of each request's existing prefix.
    """
    prefix_lens = seq_lens
    # Each request grows by exactly the number of speculative steps.
    seq_lens = prefix_lens + speculative_num_steps
    last_loc = get_last_loc(
        req_to_token,
        req_pool_indices,
        prefix_lens,
    )
    return prefix_lens, seq_lens, last_loc
2025-08-22 02:05:02 -07:00
# Disable torch.compile for this function because it will be
# even slower.
# @torch.compile(dynamic=True)
def get_last_loc_large_page_size_large_top_k(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    speculative_num_steps: int,
    topk: int,
    page_size: int,
):
    """Compute paged allocation geometry for speculative extend with top-k > 1.

    Returns a 5-tuple of (prefix_lens, new seq_lens, last_loc,
    num_new_pages_per_topk, extend_lens) describing how many fresh pages each
    of the ``topk`` draft branches needs and how far each sequence extends.
    """
    prefix_lens = seq_lens
    # Tokens already occupying the last, partially-filled page.
    tail_lens = prefix_lens % page_size
    # Pages required per draft branch: ceil((tail + steps) / page_size).
    pages_per_branch = (
        tail_lens + speculative_num_steps + page_size - 1
    ) // page_size
    # New length = full-page-aligned prefix + fresh pages for all branches.
    aligned_prefix = prefix_lens // page_size * page_size
    new_seq_lens = aligned_prefix + pages_per_branch * page_size * topk
    extend_lens = new_seq_lens - prefix_lens
    last_loc = get_last_loc(
        req_to_token,
        req_pool_indices,
        prefix_lens,
    )
    return prefix_lens, new_seq_lens, last_loc, pages_per_branch, extend_lens