1014 lines
39 KiB
Python
1014 lines
39 KiB
Python
import logging
|
|
import os
|
|
import time
|
|
from contextlib import contextmanager
|
|
from typing import List, Optional, Tuple
|
|
|
|
import torch
|
|
from huggingface_hub import snapshot_download
|
|
|
|
from sglang.srt.distributed import (
|
|
GroupCoordinator,
|
|
get_tp_group,
|
|
patch_tensor_parallel_group,
|
|
)
|
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
|
from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
|
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
|
from sglang.srt.managers.scheduler import GenerationBatchResult
|
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
|
from sglang.srt.mem_cache.common import (
|
|
alloc_paged_token_slots_extend,
|
|
alloc_token_slots,
|
|
get_last_loc,
|
|
)
|
|
from sglang.srt.model_executor.forward_batch_info import (
|
|
CaptureHiddenMode,
|
|
ForwardBatch,
|
|
ForwardMode,
|
|
)
|
|
from sglang.srt.server_args import ServerArgs
|
|
from sglang.srt.speculative.draft_utils import DraftBackendFactory
|
|
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
|
EAGLEDraftCudaGraphRunner,
|
|
)
|
|
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
|
|
EAGLEDraftExtendCudaGraphRunner,
|
|
)
|
|
from sglang.srt.speculative.eagle_info import (
|
|
EagleDraftInput,
|
|
EagleVerifyInput,
|
|
EagleVerifyOutput,
|
|
)
|
|
from sglang.srt.speculative.eagle_utils import (
|
|
build_tree_kernel_efficient,
|
|
organize_draft_results,
|
|
)
|
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
|
from sglang.srt.speculative.spec_utils import (
|
|
assign_draft_cache_locs,
|
|
fast_topk,
|
|
generate_token_bitmask,
|
|
select_top_k_tokens,
|
|
)
|
|
from sglang.srt.utils import (
|
|
empty_context,
|
|
get_available_gpu_memory,
|
|
get_bool_env_var,
|
|
is_blackwell,
|
|
is_cuda,
|
|
next_power_of_2,
|
|
)
|
|
|
|
if is_cuda():
|
|
from sgl_kernel import segment_packbits
|
|
|
|
logger: logging.Logger = logging.getLogger(__name__)

# Read once at import time. When truthy, EAGLEWorker.add_logprob_values
# computes logprobs from the raw target logits instead of the
# temperature-scaled logits.
RETURN_ORIGINAL_LOGPROB = get_bool_env_var("RETURN_ORIGINAL_LOGPROB")
|
|
|
|
|
|
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    """Run the enclosed code with the draft model's own tensor-parallel group.

    The draft model does not use DP and owns a separate TP group, so for the
    duration of the ``with`` block this delegates to
    ``patch_tensor_parallel_group(tp_group)``. mscclpp is disabled for now
    because it does not support two communication groups.
    """
    patched_group_ctx = patch_tensor_parallel_group(tp_group)
    with patched_group_ctx:
        yield
|
|
|
|
|
|
class EAGLEWorker(TpModelWorker):
    """Speculative-decoding worker that drives an EAGLE-style draft model.

    The draft model is built by this worker (``is_draft_worker=True``). It
    shares the request/token pool allocators and the embedding (plus the
    lm_head, except for EAGLE3 checkpoints that do not load it from the
    target) with ``target_worker``.

    A decode step does three things:
    1. Draft a token tree with the draft model.
    2. Verify the tree with the target model.
    3. Re-run the draft model over the accepted tokens to prepare the next
       step.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        """Build the draft worker on top of ``target_worker``.

        NOTE: this mutates ``server_args`` in place:
        - ``context_length`` is overridden with the target's value.
        - ``json_model_override_args`` may be set when a token map is given.
        - ``disable_cuda_graph`` is forced to True while the base class
          initializes. The saved value is afterwards written back to
          ``draft_model_runner.server_args``.
        """
        # Parse arguments
        self.server_args = server_args
        self.topk = server_args.speculative_eagle_topk
        self.speculative_num_steps = server_args.speculative_num_steps
        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
        self.enable_nan_detection = server_args.enable_nan_detection
        self.gpu_id = gpu_id
        self.device = server_args.device
        self.target_worker = target_worker
        self.page_size = server_args.page_size
        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
        self.padded_static_len = -1

        # Override the context length of the draft model to be the same as the target model.
        server_args.context_length = target_worker.model_runner.model_config.context_len

        # Do not capture cuda graph in `super().__init__()`
        # It will be captured later.
        backup_disable_cuda_graph = server_args.disable_cuda_graph
        server_args.disable_cuda_graph = True
        # Share the allocator with a target worker.
        # Draft and target worker own their own KV cache pools.
        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
            target_worker.get_memory_pool()
        )

        # Load hot token ids
        if self.speculative_algorithm.is_eagle3():
            if server_args.speculative_token_map is not None:
                logger.warning(
                    "Speculative token map specified, but EAGLE3 models already have this. Ignoring the specified token map."
                )
            self.hot_token_id = None
        elif server_args.speculative_token_map is not None:
            self.hot_token_id = load_token_map(server_args.speculative_token_map)
            server_args.json_model_override_args = (
                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
            )
        else:
            self.hot_token_id = None

        # Init draft worker
        with empty_context():
            super().__init__(
                server_args=server_args,
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                pp_rank=0,  # FIXME
                dp_rank=dp_rank,
                moe_ep_rank=moe_ep_rank,
                nccl_port=nccl_port,
                is_draft_worker=True,
                req_to_token_pool=self.req_to_token_pool,
                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
            )

        embed, head = self.target_worker.model_runner.model.get_embed_and_head()

        if self.speculative_algorithm.is_eagle3():
            # most cases EAGLE3 models don't share lm_head
            # but some models (e.g. nvidia/gpt-oss-120b-Eagle3) shares
            if (
                hasattr(self.draft_model_runner.model, "load_lm_head_from_target")
                and self.draft_model_runner.model.load_lm_head_from_target
            ):
                self.draft_model_runner.model.set_embed_and_head(embed, head)
            else:
                self.draft_model_runner.model.set_embed(embed)

            # grab hot token ids
            if self.draft_model_runner.model.hot_token_id is not None:
                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
                    embed.device
                )
        else:
            if self.hot_token_id is not None:
                # Restrict the (cloned) lm_head rows to the hot vocabulary so
                # the draft model predicts indices into hot_token_id.
                head = head.clone()
                self.hot_token_id = self.hot_token_id.to(head.device)
                head.data = head.data[self.hot_token_id]

            # Share the embedding and lm_head
            self.draft_model_runner.model.set_embed_and_head(embed, head)

        # Init attention backend and cuda graphs
        self.draft_model_runner.server_args.disable_cuda_graph = (
            backup_disable_cuda_graph
        )
        self.draft_tp_context = (
            draft_tp_context if server_args.enable_dp_attention else empty_context
        )
        with self.draft_tp_context(self.draft_model_runner.tp_group):
            self.init_attention_backend()
            self.init_cuda_graphs()

        # Some dummy tensors. They are only overwritten with real values when
        # page_size > 1 and topk > 1 (see _draft_preprocess_decode).
        self.num_new_pages_per_topk = torch.empty(
            (), dtype=torch.int64, device=self.device
        )
        self.extend_lens = torch.empty((), dtype=torch.int64, device=self.device)

    def init_attention_backend(self):
        """Create the multi-step draft decode and draft-extend attention backends."""
        self.has_prefill_wrapper_verify = False
        self.draft_extend_attn_backend = None

        draft_backend_factory = DraftBackendFactory(
            self.server_args,
            self.draft_model_runner,
            self.topk,
            self.speculative_num_steps,
        )

        # Initialize decode attention backend
        self.draft_attn_backend = draft_backend_factory.create_decode_backend()

        # Initialize draft extend attention backend (respects speculative_attention_mode setting)
        self.draft_extend_attn_backend = (
            draft_backend_factory.create_draft_extend_backend()
        )

        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend

    def init_cuda_graphs(self):
        """Capture CUDA graphs for the draft decode and draft-extend passes.

        Both runners stay ``None`` when CUDA graphs are disabled. The
        draft-extend runner is only built when a draft-extend attention
        backend exists.
        """
        self.cuda_graph_runner = None
        self.cuda_graph_runner_for_draft_extend = None

        if self.server_args.disable_cuda_graph:
            return

        # Capture draft
        self.cuda_graph_runner = self._capture_draft_cuda_graph(
            "draft", EAGLEDraftCudaGraphRunner
        )

        # Capture extend
        if self.draft_extend_attn_backend:
            self.cuda_graph_runner_for_draft_extend = self._capture_draft_cuda_graph(
                "draft extend", EAGLEDraftExtendCudaGraphRunner
            )

    def _capture_draft_cuda_graph(self, name: str, runner_cls):
        """Build ``runner_cls(self)``, logging elapsed time and GPU memory usage."""
        tic = time.perf_counter()
        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture {name} cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
        )
        runner = runner_cls(self)
        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Capture {name} cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
        )
        return runner

    @property
    def draft_model_runner(self):
        """The model runner built by the base class, i.e. the draft model's runner."""
        return self.model_runner

    def forward_batch_generation(self, batch: ScheduleBatch) -> GenerationBatchResult:
        """Run one speculative-decoding step for ``batch``.

        For extend batches, the target model runs over the new tokens and its
        hidden states are then used to prefill the draft model. For all other
        batches (decode and idle), a token tree is drafted, verified by the
        target model, and the draft model is re-extended over the accepted
        tokens.

        NOTE: Many fields of ``batch`` are modified along the way. It is not
        guaranteed that the batch ends up in the same state it was passed in.

        Args:
            batch: The batch to run forward. The state of the batch is modified as it runs.

        Returns:
            A ``GenerationBatchResult`` with these fields:
            - the target logits output;
            - the next token ids (the verified ids for decode);
            - the total number of accepted draft tokens (0 for extend);
            - whether the target forward could run with CUDA graphs
              (always False for extend).
        """
        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
            logits_output, next_token_ids, seq_lens_cpu = self.forward_target_extend(
                batch
            )
            with self.draft_tp_context(self.draft_model_runner.tp_group):
                self.forward_draft_extend(
                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
                )
            return GenerationBatchResult(
                logits_output=logits_output,
                next_token_ids=next_token_ids,
                num_accepted_tokens=0,
                can_run_cuda_graph=False,
            )

        with self.draft_tp_context(self.draft_model_runner.tp_group):
            spec_info = self.draft(batch)
        logits_output, verify_output, _, can_run_cuda_graph = self.verify(
            batch, spec_info
        )

        with self.draft_tp_context(self.draft_model_runner.tp_group):
            # NOTE: We should use `check_forward_draft_extend_after_decode`
            # when DP attention is enabled, but it is slow. Skip it for now.
            if (
                self.server_args.enable_dp_attention
                or batch.spec_info.verified_id.shape[0] > 0
            ):
                # decode is not finished
                self.forward_draft_extend_after_decode(batch)

        return GenerationBatchResult(
            logits_output=logits_output,
            next_token_ids=verify_output.verified_id,
            num_accepted_tokens=sum(verify_output.accept_length_per_req_cpu),
            can_run_cuda_graph=can_run_cuda_graph,
        )

    def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
        """Return whether any rank has verified tokens left to extend.

        Without DP attention only the local batch is inspected. With DP
        attention the local flag is all-reduced (summed) over the TP CPU group.
        """
        local_need_forward = batch.spec_info.verified_id.shape[0] > 0
        if not self.server_args.enable_dp_attention:
            return local_need_forward

        global_need_forward = torch.tensor([local_need_forward], dtype=torch.int64)
        torch.distributed.all_reduce(
            global_need_forward, group=get_tp_group().cpu_group
        )
        return global_need_forward[0].item() > 0

    def forward_target_extend(
        self, batch: ScheduleBatch
    ) -> Tuple[LogitsProcessorOutput, torch.Tensor, Optional[torch.Tensor]]:
        """Run the target model over an extend batch.

        Args:
            batch: The batch to run. States could be modified.

        Returns:
            logits_output: The output of logits. It will contain the full hidden states.
            next_token_ids: Next token ids generated.
            seq_lens_cpu: ``seq_lens_cpu`` of the model-worker batch, forwarded
                to the draft extend.
        """
        # Forward with the target model and get hidden states.
        # We need the full hidden states to prefill the KV cache of the draft model.
        model_worker_batch = batch.get_model_worker_batch()
        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
        batch_result = self.target_worker.forward_batch_generation(model_worker_batch)
        return (
            batch_result.logits_output,
            batch_result.next_token_ids,
            model_worker_batch.seq_lens_cpu,
        )

    def _draft_preprocess_decode(self, batch: ScheduleBatch):
        """Allocate KV slots for the draft steps and prepare decode inputs.

        The allocator state is snapshotted before allocation
        (``backup_state=True``) and restored at the end of this method.
        ``batch.out_cache_loc`` holds the slots used by the draft forwards.
        """
        # Parse args
        num_seqs = batch.batch_size()
        spec_info = batch.spec_info

        # Accumulate penalty
        if batch.sampling_info.penalizer_orchestrator.is_required:
            # This is a relaxed version of penalties for speculative decoding.
            batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                spec_info.verified_id.to(torch.int64)
            )

        # Allocate cache locations
        # Layout of the out_cache_loc
        # [       topk 0         ] [       topk 1         ]
        # [iter=0, iter=1, iter=2] [iter=0, iter=1, iter=2]
        if self.page_size == 1:
            out_cache_loc, token_to_kv_pool_state_backup = alloc_token_slots(
                batch.tree_cache,
                num_seqs * self.speculative_num_steps * self.topk,
                backup_state=True,
            )
        else:
            if self.topk == 1:
                prefix_lens, seq_lens, last_loc = get_last_loc_large_page_size_top_k_1(
                    batch.req_to_token_pool.req_to_token,
                    batch.req_pool_indices,
                    batch.seq_lens,
                    self.speculative_num_steps,
                )
                prefix_lens_cpu = batch.seq_lens_cpu
                seq_lens_cpu = batch.seq_lens_cpu + self.speculative_num_steps
                extend_num_tokens = num_seqs * self.speculative_num_steps
            else:
                # In this case, the last partial page needs to be duplicated.
                # KV cache layout in batch.req_to_token_pool.req_to_token:
                #
                # | -------- | -- xxxx .. | -- xxxx .. | -- xxxx .. |
                #    prefix     top-k = 0    top-k = 1    top-k = 2
                #
                #  "-" means prefix tokens
                #  "x" means speculative draft tokens
                #  "." means padded tokens

                # TODO(lmzheng): The current implementation is still a fake support
                # for page size > 1. In the `assign_draft_cache_locs` below,
                # we directly move the indices instead of the real kv cache.
                # This only works when the kernel backend runs with page size = 1.
                # If the kernel backend runs with page size > 1, we need to
                # duplicate the real KV cache. The overhead of duplicating KV
                # cache seems okay because the draft KV cache only has one layer.
                # see a related copy operation in MHATokenToKVPool::move_kv_cache.

                (
                    prefix_lens,
                    seq_lens,
                    last_loc,
                    self.num_new_pages_per_topk,
                    self.extend_lens,
                ) = get_last_loc_large_page_size_large_top_k(
                    batch.req_to_token_pool.req_to_token,
                    batch.req_pool_indices,
                    batch.seq_lens,
                    self.speculative_num_steps,
                    self.topk,
                    self.page_size,
                )
                # CPU mirror of the device-side computation above.
                prefix_lens_cpu = batch.seq_lens_cpu
                last_page_lens = prefix_lens_cpu % self.page_size
                num_new_pages_per_topk = (
                    last_page_lens + self.speculative_num_steps + self.page_size - 1
                ) // self.page_size
                seq_lens_cpu = (
                    prefix_lens_cpu // self.page_size * self.page_size
                    + num_new_pages_per_topk * (self.page_size * self.topk)
                )
                extend_num_tokens = torch.sum((seq_lens_cpu - prefix_lens_cpu)).item()

            out_cache_loc, token_to_kv_pool_state_backup = (
                alloc_paged_token_slots_extend(
                    batch.tree_cache,
                    prefix_lens,
                    prefix_lens_cpu,
                    seq_lens,
                    seq_lens_cpu,
                    last_loc,
                    extend_num_tokens,
                    backup_state=True,
                )
            )

        assign_draft_cache_locs[(num_seqs,)](
            batch.req_pool_indices,
            batch.req_to_token_pool.req_to_token,
            batch.seq_lens,
            self.extend_lens,
            self.num_new_pages_per_topk,
            out_cache_loc,
            batch.req_to_token_pool.req_to_token.shape[1],
            self.topk,
            self.speculative_num_steps,
            self.page_size,
            next_power_of_2(num_seqs),
            next_power_of_2(self.speculative_num_steps),
        )

        if self.page_size > 1 and self.topk > 1:
            # Remove padded slots
            out_cache_loc = out_cache_loc[
                : num_seqs * self.topk * self.speculative_num_steps
            ]

        batch.out_cache_loc = out_cache_loc
        batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
        batch.return_hidden_states = False
        spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)

    def _draft_preprocess_idle(self, batch: ScheduleBatch):
        """Give an idle batch a placeholder draft input."""
        batch.spec_info = EagleDraftInput.create_idle_input(
            device=self.device,
            hidden_size=self.model_config.hidden_size,
            dtype=self.model_config.dtype,
            topk=self.topk,
            capture_hidden_mode=CaptureHiddenMode.LAST,
        )

    def draft(self, batch: ScheduleBatch):
        """Run the multi-step draft and build the token tree to verify.

        Returns:
            An ``EagleVerifyInput`` describing the drafted tree. For idle
            batches this is ``EagleVerifyInput.create_idle_input(...)``.
        """
        # Parse args
        if batch.forward_mode.is_idle():
            self._draft_preprocess_idle(batch)
        else:
            self._draft_preprocess_decode(batch)

        spec_info = batch.spec_info
        assert isinstance(spec_info, EagleDraftInput)

        spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        spec_info.num_tokens_per_batch = self.topk
        spec_info.num_tokens_for_logprob_per_batch = self.topk
        batch.return_hidden_states = False

        # Get forward batch
        model_worker_batch = batch.get_model_worker_batch()
        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
            forward_batch
        )
        if can_cuda_graph:
            parent_list, top_scores_index, draft_tokens = self.cuda_graph_runner.replay(
                forward_batch
            )
        else:
            forward_batch.can_run_dp_cuda_graph = False
            if not forward_batch.forward_mode.is_idle():
                # Initialize attention backend
                self.draft_attn_backend.init_forward_metadata(forward_batch)
            # Run forward steps (also for idle batches)
            parent_list, top_scores_index, draft_tokens = self.draft_forward(
                forward_batch
            )

        if batch.forward_mode.is_idle():
            return EagleVerifyInput.create_idle_input(
                self.topk,
                self.speculative_num_steps,
                self.speculative_num_draft_tokens,
            )

        (
            tree_mask,
            position,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            draft_tokens,
        ) = build_tree_kernel_efficient(
            spec_info.verified_id,
            parent_list,
            top_scores_index,
            draft_tokens,
            batch.seq_lens,
            batch.seq_lens_sum,
            self.topk,
            self.speculative_num_steps,
            self.speculative_num_draft_tokens,
        )

        return EagleVerifyInput(
            draft_token=draft_tokens,
            custom_mask=tree_mask,
            positions=position,
            retrive_index=retrive_index,
            retrive_next_token=retrive_next_token,
            retrive_next_sibling=retrive_next_sibling,
            retrive_cum_len=None,
            spec_steps=self.speculative_num_steps,
            topk=self.topk,
            draft_token_num=self.server_args.speculative_num_draft_tokens,
            capture_hidden_mode=CaptureHiddenMode.FULL,
            seq_lens_sum=forward_batch.seq_lens_sum,
            seq_lens_cpu=forward_batch.seq_lens_cpu,
        )

    def draft_forward(self, forward_batch: ForwardBatch):
        """Expand top-k candidates for ``speculative_num_steps`` steps.

        The first step reuses the top-k already stored on the draft input. The
        remaining ``speculative_num_steps - 1`` steps each run one draft
        forward.

        Returns:
            ``(parent_list, top_scores_index, draft_tokens)`` as produced by
            ``organize_draft_results``.
        """
        # Parse args
        spec_info = forward_batch.spec_info
        assert isinstance(spec_info, EagleDraftInput)
        out_cache_loc = forward_batch.out_cache_loc
        topk_p, topk_index, hidden_states = (
            spec_info.topk_p,
            spec_info.topk_index,
            spec_info.hidden_states,
        )
        if self.hot_token_id is not None:
            # Map hot-vocabulary indices back to full-vocabulary token ids.
            topk_index = self.hot_token_id[topk_index]

        # (bs, topk, steps) -> (steps, bs * topk): one row of slots per step.
        out_cache_loc = out_cache_loc.reshape(
            forward_batch.batch_size, self.topk, self.speculative_num_steps
        )
        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
            self.speculative_num_steps, -1
        )

        # Return values
        score_list: List[torch.Tensor] = []
        token_list: List[torch.Tensor] = []
        parents_list: List[torch.Tensor] = []

        # Forward multiple steps
        scores = None
        for i in range(self.speculative_num_steps):
            input_ids, hidden_states, scores, tree_info = select_top_k_tokens(
                i, topk_p, topk_index, hidden_states, scores, self.topk
            )
            score_list.append(tree_info[0])
            token_list.append(tree_info[1])
            parents_list.append(tree_info[2])

            # We don't need to run the last forward. we get 1 token from draft prefill and (#spec steps - 1) tokens here
            if i == self.speculative_num_steps - 1:
                break

            # Set inputs
            forward_batch.input_ids = input_ids
            # This is a temporary fix for the case that the user is using standalone
            # speculative decoding and the draft model architecture is gpt-oss. gpt-oss
            # rope kernel needs cache_loc to be contiguous.
            if (
                self.server_args.speculative_algorithm == "STANDALONE"
                and self.model_config.hf_config.architectures[0] == "GptOssForCausalLM"
            ):
                out_cache_loc = out_cache_loc.contiguous()
            forward_batch.out_cache_loc = out_cache_loc[i]
            forward_batch.positions.add_(1)
            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
            spec_info.hidden_states = hidden_states

            # Run forward
            logits_output, _ = self.draft_model_runner.forward(
                forward_batch, skip_attn_backend_init=True
            )
            self._detect_nan_if_needed(logits_output)
            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
            if self.hot_token_id is not None:
                topk_index = self.hot_token_id[topk_index]
            hidden_states = logits_output.hidden_states

        parent_list, top_scores_index, draft_tokens = organize_draft_results(
            score_list, token_list, parents_list, self.speculative_num_draft_tokens
        )

        return parent_list, top_scores_index, draft_tokens

    def clear_cache_pool(self):
        """Clear the request-to-token pool and the KV allocator of the draft runner."""
        self.model_runner.req_to_token_pool.clear()
        self.model_runner.token_to_kv_pool_allocator.clear()

    def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
        """Verify the drafted tree with the target model.

        Returns:
            ``(logits_output, res, model_worker_batch, can_run_cuda_graph)``.
            ``logits_output`` is narrowed to the accepted positions. On return,
            ``batch.forward_mode`` is DECODE (or IDLE) and ``batch.spec_info``
            is ``res.draft_input``.
        """
        spec_info.prepare_for_verify(batch, self.page_size)
        batch.return_hidden_states = False
        batch.forward_mode = (
            ForwardMode.TARGET_VERIFY
            if not batch.forward_mode.is_idle()
            else ForwardMode.IDLE
        )
        batch.spec_info = spec_info

        model_worker_batch = batch.get_model_worker_batch(
            seq_lens_cpu_cache=spec_info.seq_lens_cpu
        )
        assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode

        if batch.has_grammar:
            # Copy tree structure to CPU before the forward so the bitmask can
            # be generated while the GPU runs.
            retrieve_next_token_cpu = spec_info.retrive_next_token.cpu()
            retrieve_next_sibling_cpu = spec_info.retrive_next_sibling.cpu()
            draft_tokens_cpu = spec_info.draft_token.view(
                spec_info.retrive_next_token.shape
            ).cpu()

        # Forward
        batch_result = self.target_worker.forward_batch_generation(
            model_worker_batch, is_verify=True
        )
        logits_output, can_run_cuda_graph = (
            batch_result.logits_output,
            batch_result.can_run_cuda_graph,
        )

        vocab_mask = None
        if batch.has_grammar:
            # Generate the logit mask for structured output.
            # Overlap the CPU operations for bitmask generation with the forward pass.
            vocab_mask = generate_token_bitmask(
                batch.reqs,
                spec_info,
                retrieve_next_token_cpu,
                retrieve_next_sibling_cpu,
                draft_tokens_cpu,
                batch.sampling_info.vocab_size,
            )

            if vocab_mask is not None:
                assert spec_info.grammar is not None
                vocab_mask = vocab_mask.to(spec_info.retrive_next_token.device)
                # NOTE (sk): otherwise, this vocab mask will be the one from the previous extend stage
                # and will be applied to produce wrong results
                batch.sampling_info.vocab_mask = None

        self._detect_nan_if_needed(logits_output)
        spec_info.hidden_states = logits_output.hidden_states
        res: EagleVerifyOutput = spec_info.verify(
            batch,
            logits_output,
            self.token_to_kv_pool_allocator,
            self.page_size,
            vocab_mask,
        )

        # Post process based on verified outputs.
        # Pick indices that we care (accepted)
        logits_output.next_token_logits = logits_output.next_token_logits[
            res.accepted_indices
        ]
        logits_output.hidden_states = logits_output.hidden_states[res.accepted_indices]

        # QQ: can be optimized
        if self.target_worker.model_runner.hybrid_gdn_config is not None:
            # res.draft_input.accept_length is on GPU but may be empty for last verify?
            accepted_length = (
                torch.tensor(
                    res.accept_length_per_req_cpu,
                    device=logits_output.hidden_states.device,
                    dtype=torch.int32,
                )
                + 1
            )
            self.target_worker.model_runner.attn_backend.update_mamba_state_after_mtp_verify(
                accepted_length, self.target_worker.model_runner.model
            )

        if batch.return_logprob:
            self.add_logprob_values(batch, res, logits_output)

        # Prepare the batch for the next draft forwards.
        batch.forward_mode = (
            ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
        )
        batch.spec_info = res.draft_input

        return logits_output, res, model_worker_batch, can_run_cuda_graph

    def add_logprob_values(
        self,
        batch: ScheduleBatch,
        res: EagleVerifyOutput,
        logits_output: LogitsProcessorOutput,
    ):
        """Compute logprobs for the accepted tokens and append them to requests.

        Sets ``next_token_logprobs`` (and top-k / token-id logprobs when
        requested) on ``res.logits_output``. It then appends per-token values
        to every request that has ``return_logprob`` set.

        Args:
            batch: The verified batch. ``batch.spec_info`` must still be the
                verify input, because its ``draft_token_num`` maps accepted
                (flattened) indices back to request rows.
            res: The verification output.
            logits_output: Ignored. ``res.logits_output`` is used instead. The
                parameter is kept for interface compatibility; in ``verify`` it
                is presumably the same object.
        """
        # Extract args
        logits_output = res.logits_output
        top_logprobs_nums = batch.top_logprobs_nums
        token_ids_logprobs = batch.token_ids_logprobs
        accepted_indices = res.accepted_indices
        assert len(accepted_indices) == len(logits_output.next_token_logits)

        next_token_logits = logits_output.next_token_logits
        if not RETURN_ORIGINAL_LOGPROB:
            # acceptance indices are the indices in a "flattened" batch.
            # dividing it by num_draft_tokens yields the actual batch index.
            num_draft_tokens = batch.spec_info.draft_token_num
            temperatures = batch.sampling_info.temperatures[
                accepted_indices // num_draft_tokens
            ]
            next_token_logits = next_token_logits / temperatures
        logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)

        batch_next_token_ids = res.verified_id
        num_tokens_per_req = [accept + 1 for accept in res.accept_length_per_req_cpu]

        # Repeat per-request settings once per emitted token of that request.
        top_logprobs_nums_repeat_interleaved = [
            num
            for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req)
            for _ in range(num_tokens)
        ]
        token_ids_logprobs_repeat_interleaved = [
            token_ids
            for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req)
            for _ in range(num_tokens)
        ]

        # Extract logprobs
        if any(x > 0 for x in top_logprobs_nums):
            (
                logits_output.next_token_top_logprobs_val,
                logits_output.next_token_top_logprobs_idx,
            ) = get_top_logprobs(
                logprobs,
                top_logprobs_nums_repeat_interleaved,
            )

        # NOTE(review): token-id logprobs are stored on logits_output but not
        # appended to the requests below; confirm whether another component
        # consumes them.
        if any(x is not None for x in token_ids_logprobs):
            (
                logits_output.next_token_token_ids_logprobs_val,
                logits_output.next_token_token_ids_logprobs_idx,
            ) = get_token_ids_logprobs(
                logprobs,
                token_ids_logprobs_repeat_interleaved,
            )

        logits_output.next_token_logprobs = logprobs[
            torch.arange(len(batch_next_token_ids), device=batch.sampling_info.device),
            batch_next_token_ids,
        ]

        # Add output logprobs to the request
        next_token_logprobs = logits_output.next_token_logprobs.tolist()
        verified_ids = batch_next_token_ids.tolist()
        pt = 0
        for req, num_tokens in zip(batch.reqs, num_tokens_per_req, strict=True):
            if req.return_logprob:
                for idx in range(pt, pt + num_tokens):
                    req.output_token_logprobs_val.append(next_token_logprobs[idx])
                    req.output_token_logprobs_idx.append(verified_ids[idx])
                    if req.top_logprobs_num > 0:
                        req.output_top_logprobs_val.append(
                            logits_output.next_token_top_logprobs_val[idx]
                        )
                        req.output_top_logprobs_idx.append(
                            logits_output.next_token_top_logprobs_idx[idx]
                        )
            pt += num_tokens

    def forward_draft_extend(
        self,
        batch: ScheduleBatch,
        hidden_states: torch.Tensor,
        next_token_ids: torch.Tensor,
        seq_lens_cpu: Optional[torch.Tensor],
    ):
        """Run draft model extend. This API modifies the states of the batch.

        Args:
            batch: The batch to run.
            hidden_states: Hidden states from the target model forward.
            next_token_ids: Next token ids generated from the target forward.
            seq_lens_cpu: Passed as ``seq_lens_cpu_cache`` when building the
                model-worker batch.
        """
        batch.spec_info = EagleDraftInput(
            hidden_states=hidden_states,
            verified_id=next_token_ids,
            num_tokens_per_batch=1,
            num_tokens_for_logprob_per_batch=1,
        )
        batch.return_hidden_states = False
        batch.spec_info.prepare_for_extend(batch)
        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
        model_worker_batch = batch.get_model_worker_batch(
            seq_lens_cpu_cache=seq_lens_cpu
        )
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        forward_batch.return_logprob = False
        logits_output, _ = self.draft_model_runner.forward(forward_batch)
        self._detect_nan_if_needed(logits_output)
        assert isinstance(forward_batch.spec_info, EagleDraftInput)
        assert forward_batch.spec_info is batch.spec_info
        self.capture_for_decode(logits_output, forward_batch.spec_info)

        # Filter the draft input down to requests that have not finished.
        unfinished_req_index = [
            i for i, req in enumerate(batch.reqs) if not req.finished()
        ]
        if len(unfinished_req_index) < len(batch.reqs):
            unfinished_index_device = torch.tensor(
                unfinished_req_index,
                dtype=torch.int64,
                device=batch.spec_info.topk_p.device,
            )
            batch.spec_info.filter_batch(
                unfinished_index_device, has_been_filtered=False
            )

    def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
        """Re-run the draft model over the tokens accepted by verification.

        This refreshes ``topk_p``, ``topk_index`` and ``hidden_states`` on the
        draft input for the next draft step. Afterwards it restores the
        following fields on ``batch``: ``seq_lens``, ``seq_lens_cpu``,
        ``req_pool_indices``, ``spec_info.accept_length``, ``return_logprob``
        and ``forward_mode``.

        If a non-idle batch has no verified ids, the work is done on an idle
        copy of the batch instead.
        """
        assert isinstance(batch.spec_info, EagleDraftInput)
        # Backup fields that will be modified in-place
        seq_lens_backup = batch.seq_lens.clone()
        seq_lens_cpu_backup = batch.seq_lens_cpu.clone()
        req_pool_indices_backup = batch.req_pool_indices
        accept_length_backup = batch.spec_info.accept_length
        return_logprob_backup = batch.return_logprob

        input_is_idle = batch.forward_mode.is_idle()

        if not input_is_idle and batch.spec_info.verified_id.numel() == 0:
            batch = batch.copy()
            batch.prepare_for_idle()
            # EAGLE3 draft extend consumes target hidden states that are 3x
            # the model hidden size here.
            hidden_size = (
                self.model_config.hidden_size * 3
                if self.speculative_algorithm.is_eagle3()
                else self.model_config.hidden_size
            )
            batch.spec_info = EagleDraftInput.create_idle_input(
                device=self.device,
                hidden_size=hidden_size,
                dtype=self.model_config.dtype,
                topk=self.topk,
                capture_hidden_mode=CaptureHiddenMode.LAST,
            )

        batch.spec_info.num_tokens_per_batch = self.speculative_num_steps + 1
        batch.spec_info.num_tokens_for_logprob_per_batch = 1
        batch.spec_info.prepare_extend_after_decode(
            batch,
            self.speculative_num_steps,
        )
        batch.forward_mode = (
            ForwardMode.DRAFT_EXTEND
            if not batch.forward_mode.is_idle()
            else ForwardMode.IDLE
        )

        batch.return_hidden_states = False
        model_worker_batch = batch.get_model_worker_batch()
        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
        forward_batch = ForwardBatch.init_new(
            model_worker_batch, self.draft_model_runner
        )
        if forward_batch.seq_lens_cpu is not None:
            forward_batch.seq_lens_sum = forward_batch.seq_lens_cpu.sum().item()
        else:
            forward_batch.seq_lens_sum = batch.seq_lens.sum().item()

        # Run
        can_cuda_graph = (
            self.cuda_graph_runner_for_draft_extend
            and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch)
        )
        if can_cuda_graph:
            logits_output = self.cuda_graph_runner_for_draft_extend.replay(
                forward_batch
            )
            forward_batch.spec_info.topk_p, forward_batch.spec_info.topk_index = (
                logits_output.topk_p,
                logits_output.topk_index,
            )
            forward_batch.spec_info.hidden_states = logits_output.hidden_states
        else:
            forward_batch.can_run_dp_cuda_graph = False
            if not forward_batch.forward_mode.is_idle():
                self.draft_model_runner.attn_backend.init_forward_metadata(
                    forward_batch
                )
            logits_output, _ = self.draft_model_runner.forward(
                forward_batch, skip_attn_backend_init=True
            )
            self.capture_for_decode(logits_output, forward_batch.spec_info)

        self._detect_nan_if_needed(logits_output)

        # Restore backup.
        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
        batch.forward_mode = (
            ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
        )
        batch.seq_lens = seq_lens_backup
        batch.seq_lens_cpu = seq_lens_cpu_backup
        batch.req_pool_indices = req_pool_indices_backup
        batch.spec_info.accept_length = accept_length_backup
        batch.return_logprob = return_logprob_backup

    def capture_for_decode(
        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
    ):
        """Store the softmax top-k and hidden states on ``draft_input`` for the next draft."""
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
        draft_input.hidden_states = logits_output.hidden_states

    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
        """Raise ``ValueError`` if NaN detection is enabled and the logits contain NaN."""
        if self.enable_nan_detection:
            logits = logits_output.next_token_logits
            if torch.any(torch.isnan(logits)):
                logger.error("Detected errors during sampling! NaN in the logits.")
                raise ValueError("Detected errors during sampling! NaN in the logits.")
|
|
|
|
|
|
def load_token_map(token_map_path: str) -> torch.Tensor:
    """Load the "hot" token ids used to shrink the draft lm_head.

    Args:
        token_map_path: A local file path, or a Hugging Face Hub path of the
            form ``<repo_id>/<file name>``. If the local path does not exist,
            the repository named by the directory part is downloaded (``*.bin``
            and ``*.safetensors`` files skipped). The file is then read from
            that snapshot.

    Returns:
        An int64 tensor holding the stored token ids.
    """
    if not os.path.exists(token_map_path):
        cache_dir = snapshot_download(
            os.path.dirname(token_map_path),
            ignore_patterns=["*.bin", "*.safetensors"],
        )
        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
    # weights_only=True restricts unpickling to tensors and primitive
    # containers; the file may come from a remote repository.
    hot_token_id = torch.load(token_map_path, weights_only=True)
    # as_tensor accepts both a plain list and an already-built tensor (without
    # torch.tensor's copy-construct warning) and casts to int64 for indexing.
    return torch.as_tensor(hot_token_id, dtype=torch.int64)
|
|
|
|
|
|
@torch.compile(dynamic=True)
def get_last_loc_large_page_size_top_k_1(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens,
    speculative_num_steps: int,
):
    """Compute draft allocation bounds for page_size > 1 with topk == 1.

    Returns:
        ``(prefix_lens, seq_lens, last_loc)``:
        - ``prefix_lens`` is the input ``seq_lens`` itself;
        - ``seq_lens`` is the input lengths plus ``speculative_num_steps``;
        - ``last_loc`` is ``get_last_loc(req_to_token, req_pool_indices,
          prefix_lens)`` (presumably the KV slot of each request's last
          token; confirm in get_last_loc).
    """
    extended_lens = seq_lens + speculative_num_steps
    last_loc = get_last_loc(req_to_token, req_pool_indices, seq_lens)
    return seq_lens, extended_lens, last_loc
|
|
|
|
|
|
# Disable torch.compile for this function because it will be
# even slower.
# @torch.compile(dynamic=True)
def get_last_loc_large_page_size_large_top_k(
    req_to_token: torch.Tensor,
    req_pool_indices: torch.Tensor,
    seq_lens: torch.Tensor,
    speculative_num_steps: int,
    topk: int,
    page_size: int,
):
    """Compute draft allocation bounds for page_size > 1 with topk > 1.

    Each top-k branch gets its own run of pages, large enough to hold a
    duplicate of the request's partial last page plus the draft tokens.

    Returns:
        ``(prefix_lens, seq_lens, last_loc, num_new_pages_per_topk,
        extend_lens)``:
        - ``prefix_lens`` is the input ``seq_lens``;
        - ``seq_lens`` is the page-aligned prefix plus the new pages of all
          branches;
        - ``last_loc`` is ``get_last_loc(...)`` at ``prefix_lens``;
        - ``num_new_pages_per_topk`` is the page count per branch;
        - ``extend_lens`` is ``seq_lens - prefix_lens``.
    """
    prefix_lens = seq_lens
    # Tokens already occupying each request's partially-filled last page.
    partial_page_tokens = prefix_lens % page_size
    # Ceil-divide: pages per branch for the partial page plus the draft tokens.
    pages_per_branch = (
        partial_page_tokens + speculative_num_steps + page_size - 1
    ) // page_size
    page_aligned_prefix = prefix_lens // page_size * page_size
    padded_seq_lens = page_aligned_prefix + pages_per_branch * (page_size * topk)
    last_loc = get_last_loc(req_to_token, req_pool_indices, prefix_lens)
    return (
        prefix_lens,
        padded_seq_lens,
        last_loc,
        pages_per_branch,
        padded_seq_lens - prefix_lens,
    )
|