Beta spec-overlap for EAGLE (#11398)
Co-authored-by: Lianmin Zheng <15100009+merrymercy@users.noreply.github.com>
Co-authored-by: Hanming Lu <69857889+hanming-lu@users.noreply.github.com>
@@ -55,6 +55,25 @@ class AttentionBackend(ABC):
        """Get the fill value for padded seq lens. Typically, it is 0 or 1."""
        raise NotImplementedError()

    def get_verify_buffers_to_fill_after_draft(self):
        """
        Return buffers of verify attention kernels that need to be filled after draft.

        Typically, these are tree mask and position buffers.
        """
        return [None, None]

    def update_verify_buffers_to_fill_after_draft(
        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
    ):
        """
        Update the buffers returned by get_verify_buffers_to_fill_after_draft if needed.

        Here, we need to redo the computation of all metadata of the attention backend
        that depends on the tree mask and position buffers.
        """
        raise NotImplementedError()

    def forward(
        self,
        q: torch.Tensor,
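For orientation, a hedged sketch (not part of this diff) of how a concrete backend might implement the two new hooks; the buffer attributes and the re-plan helper are illustrative assumptions, not sglang APIs:

class MyAttnBackend(AttentionBackend):
    def get_verify_buffers_to_fill_after_draft(self):
        # Persistent CUDA-graph buffers that the draft stage writes into once
        # the speculation tree is built (hypothetical attribute names).
        return [self.cuda_graph_custom_mask, self.cuda_graph_positions]

    def update_verify_buffers_to_fill_after_draft(self, spec_info, cuda_graph_bs):
        # Recompute whatever verify-side metadata depends on the tree mask and
        # positions, e.g. re-plan kernels for this batch size (hypothetical helper).
        self._replan_verify_metadata(spec_info, cuda_graph_bs)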
@@ -29,7 +29,6 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
    get_int_env_var,
@@ -162,6 +162,8 @@ class TritonAttnBackend(AttentionBackend):
        # Initialize forward metadata
        self.forward_metadata: ForwardMetadata = None

        self.cuda_graph_custom_mask = None

    def get_num_kv_splits(
        self,
        num_kv_splits: torch.Tensor,
@@ -755,6 +757,19 @@ class TritonAttnBackend(AttentionBackend):
    def get_cuda_graph_seq_len_fill_value(self):
        return 1

    def get_verify_buffers_to_fill_after_draft(self):
        """
        Return buffers for verify attention kernels that need to be filled after draft.

        Typically, these are tree mask and position buffers.
        """
        return [self.cuda_graph_custom_mask, None]

    def update_verify_buffers_to_fill_after_draft(
        self, spec_info: SpecInput, cuda_graph_bs: Optional[int]
    ):
        pass

    def forward_extend(
        self,
        q: torch.Tensor,
@@ -384,6 +384,7 @@ class LogitsProcessor(nn.Module):
        if (
            logits_metadata.forward_mode.is_decode_or_idle()
            or logits_metadata.forward_mode.is_target_verify()
            or logits_metadata.forward_mode.is_draft_extend_v2()
        ):
            pruned_states = hidden_states
            if aux_hidden_states is not None:
@@ -1,11 +1,18 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional
from typing import TYPE_CHECKING, Optional

import torch

from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.utils import get_compiler_backend

if TYPE_CHECKING:
    from sglang.srt.managers.schedule_batch import ModelWorkerBatch
    from sglang.srt.managers.scheduler import GenerationBatchResult
    from sglang.srt.speculative.eagle_info import EagleDraftInput
    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm


@torch.compile(dynamic=True, backend=get_compiler_backend())
def _resolve_future_token_ids(input_ids, future_token_ids_map):
@@ -27,6 +34,7 @@ class FutureMap:
        self,
        max_running_requests: int,
        device: torch.device,
        spec_algo: Optional[SpeculativeAlgorithm] = None,
    ):
        self.future_ct = 0
        # A factor of 3 is used to avoid collision in the circular buffer.
@@ -34,9 +42,51 @@ class FutureMap:
        # A factor of 5 is used to ensure the buffer is large enough.
        self.future_buffer_len = max_running_requests * 5
        self.device = device
        self.spec_algo = spec_algo
        self.buf_initialized = False

        self.token_ids_buf = torch.empty(
            (self.future_buffer_len,), dtype=torch.int64, device=self.device
        if self.spec_algo.is_none():
            self.token_ids_buf = torch.empty(
                (self.future_buffer_len,), dtype=torch.int64, device=self.device
            )

    def _lazy_init_buf(self, draft_input: EagleDraftInput):
        if self.buf_initialized or not self.spec_algo.is_eagle():
            return

        self.buf_initialized = True

        # get the template for each tensor
        topk_p0 = draft_input.topk_p[0]
        topk_index0 = draft_input.topk_index[0]
        hidden_states0 = draft_input.hidden_states[0]
        verified_id0 = draft_input.verified_id[0]
        new_seq_lens0 = draft_input.new_seq_lens[0]

        self.topk_p_buf = torch.empty(
            (self.future_buffer_len, *topk_p0.shape),
            dtype=topk_p0.dtype,
            device=self.device,
        )
        self.topk_index_buf = torch.empty(
            (self.future_buffer_len, *topk_index0.shape),
            dtype=topk_index0.dtype,
            device=self.device,
        )
        self.hidden_states_buf = torch.empty(
            (self.future_buffer_len, *hidden_states0.shape),
            dtype=hidden_states0.dtype,
            device=self.device,
        )
        self.verified_id_buf = torch.empty(
            (self.future_buffer_len, *verified_id0.shape),
            dtype=verified_id0.dtype,
            device=self.device,
        )
        self.new_seq_lens_buf = torch.empty(
            (self.future_buffer_len, *new_seq_lens0.shape),
            dtype=new_seq_lens0.dtype,
            device=self.device,
        )

    def alloc_future_indices(self, bs: int) -> FutureIndices:
@@ -49,7 +99,32 @@ class FutureMap:
        return FutureIndices(indices=indices, interval=slice(start, end))

    def resolve_future(self, model_worker_batch: ModelWorkerBatch):
        _resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)
        if self.spec_algo.is_eagle():
            # TODO(lsyin): write future indices into spec_info.future_indices
            draft_input: EagleDraftInput = model_worker_batch.spec_info
            if draft_input is None:
                # FIXME(lsyin): No future exists, only for prefill batch, not compatible with mixed mode
                return
            indices = draft_input.future_indices.indices
            draft_input.topk_p = self.topk_p_buf[indices]
            draft_input.topk_index = self.topk_index_buf[indices]
            draft_input.hidden_states = self.hidden_states_buf[indices]
            draft_input.verified_id = self.verified_id_buf[indices]
            draft_input.new_seq_lens = self.new_seq_lens_buf[indices]
        else:
            _resolve_future_token_ids(model_worker_batch.input_ids, self.token_ids_buf)

    def store_to_map(self, future_indices: FutureIndices, next_token_ids: torch.Tensor):
        self.token_ids_buf[future_indices.interval] = next_token_ids
    def store_to_map(
        self, future_indices: FutureIndices, batch_result: GenerationBatchResult
    ):
        intv = future_indices.interval
        if self.spec_algo.is_eagle():
            draft_input: EagleDraftInput = batch_result.next_draft_input
            self._lazy_init_buf(draft_input)
            self.topk_p_buf[intv] = draft_input.topk_p
            self.topk_index_buf[intv] = draft_input.topk_index
            self.hidden_states_buf[intv] = draft_input.hidden_states
            self.verified_id_buf[intv] = draft_input.verified_id
            self.new_seq_lens_buf[intv] = draft_input.new_seq_lens
        else:
            self.token_ids_buf[intv] = batch_result.next_token_ids
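A hedged sketch (not from the diff) of how the overlap scheduler is expected to round-trip EAGLE outputs through FutureMap; the call sites below are simplified assumptions:

# Launch step: allocate indices, run the forward, stash the still-on-GPU outputs.
future_indices = future_map.alloc_future_indices(len(batch.reqs))
batch_result = model_worker.forward_batch_generation(model_worker_batch)
future_map.store_to_map(future_indices, batch_result)

# Next step: the draft input carries only the indices; resolve_future reads the
# buffered topk_p / topk_index / hidden_states / verified_id / new_seq_lens back.
next_batch.spec_info.future_indices = future_indices
future_map.resolve_future(next_model_worker_batch)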
@@ -61,8 +61,12 @@ from sglang.srt.mem_cache.allocator import (
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.common import alloc_for_decode, alloc_for_extend
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
from sglang.srt.mem_cache.common import (
    alloc_for_decode,
    alloc_for_extend,
    alloc_token_slots,
)
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.mem_cache.radix_cache import RadixKey
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
@@ -71,6 +75,7 @@ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import flatten_nested_list
from sglang.srt.utils.common import next_power_of_2

if TYPE_CHECKING:
    from sglang.srt.configs.model_config import ModelConfig
@@ -1067,6 +1072,38 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
    def is_empty(self):
        return len(self.reqs) == 0

    def allocate_for_eagle_v2(self):
        from sglang.srt.speculative.eagle_info import EagleDraftInput
        from sglang.srt.speculative.spec_utils import assign_req_to_token_pool

        bs = self.batch_size()

        assert self.spec_info.is_draft_input()
        draft_input: EagleDraftInput = self.spec_info

        # FIXME(lsyin): the current implementation does not enable over-allocation
        # Now seq_lens and allocate_lens are correct
        self.maybe_wait_verify_done()

        new_allocate_lens = self.seq_lens + EagleDraftInput.ALLOC_LEN_PER_DECODE
        num_needed_tokens = (new_allocate_lens - draft_input.allocate_lens).sum().item()
        out_cache_loc = alloc_token_slots(self.tree_cache, num_needed_tokens)

        assign_req_to_token_pool[(bs,)](
            self.req_pool_indices,
            self.req_to_token_pool.req_to_token,
            draft_input.allocate_lens,
            new_allocate_lens,
            out_cache_loc,
            self.req_to_token_pool.req_to_token.shape[1],
            next_power_of_2(bs),
        )
        draft_input.allocate_lens = new_allocate_lens

        # FIXME(lsyin): remove seq_lens_sum calculation
        self.seq_lens_cpu = self.seq_lens.cpu()
        self.seq_lens_sum = self.seq_lens_cpu.sum().item()

    def prepare_encoder_info_extend(self, input_ids: List[int], seq_lens: List[int]):
        self.encoder_lens_cpu = []
        self.encoder_cached = []
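For scale, an illustrative calculation (example values, not defaults): the V2 worker sets EagleDraftInput.ALLOC_LEN_PER_DECODE = max(speculative_num_steps * topk, speculative_num_draft_tokens), which is the per-request head-room that allocate_for_eagle_v2 adds to seq_lens each decode step.

# Example numbers only
topk, num_steps, num_draft_tokens = 8, 3, 16
alloc_len_per_decode = max(num_steps * topk, num_draft_tokens)  # = 24 KV slots per request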
@@ -1507,15 +1544,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            self.model_config.vocab_size,
        )

    @property
    def is_v2_eagle(self):
        # FIXME: finally deprecate is_v2_eagle
        return self.enable_overlap and self.spec_algorithm.is_eagle()

    def prepare_for_decode(self):
        self.forward_mode = ForwardMode.DECODE
        bs = len(self.reqs)

        if (
            self.spec_algorithm.is_eagle()
            or self.spec_algorithm.is_standalone()
            or self.spec_algorithm.is_ngram()
        ):
            if self.is_v2_eagle:
                # FIXME(lsyin): make this sync optional
                self.allocate_for_eagle_v2()

        if not self.spec_algorithm.is_none():
            # if spec decoding is used, the decode batch is prepared inside
            # `forward_batch_speculative_generation` after running draft models.
            return
@@ -1566,11 +1608,23 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        self.orig_seq_lens.add_(1)
        self.seq_lens_sum += bs

    def maybe_wait_verify_done(self):
        if self.is_v2_eagle:
            from sglang.srt.speculative.eagle_info import EagleDraftInput

            draft_input: EagleDraftInput = self.spec_info
            if draft_input.verify_done is not None:
                draft_input.verify_done.synchronize()

    def filter_batch(
        self,
        chunked_req_to_exclude: Optional[Union[Req, List[Req]]] = None,
        keep_indices: Optional[List[int]] = None,
    ):
        # FIXME(lsyin): used here to get the correct seq_lens
        # The batch has been launched but we need it verified to get correct next batch info
        self.maybe_wait_verify_done()

        if keep_indices is None:
            if isinstance(chunked_req_to_exclude, Req):
                chunked_req_to_exclude = [chunked_req_to_exclude]
@@ -1633,6 +1687,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
        )

    def merge_batch(self, other: "ScheduleBatch"):
        # NOTE: in v2 eagle mode, we do not need to wait for verify here because
        # 1) current batch is always prefill, whose seq_lens and allocate_lens are not a future
        # 2) other batch is always decode, which is finished in previous step

        # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
        # orchestrator.merge() depends on Batch.reqs during preparation of each penalizer, so it
        # needs to be called with pre-merged Batch.reqs.
@@ -1757,6 +1815,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
            can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
            is_extend_in_batch=self.is_extend_in_batch,
            is_prefill_only=self.is_prefill_only,
            seq_lens_cpu=self.seq_lens_cpu,
            enable_overlap=self.enable_overlap,
        )

    def _evict_tree_cache_if_needed(self, num_tokens: int):
@@ -148,13 +148,10 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
from sglang.srt.mem_cache.radix_cache import RadixCache
from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
from sglang.srt.model_executor.forward_batch_info import (
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.parser.reasoning_parser import ReasoningParser
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.tracing.trace import (
    process_tracing_init,
@@ -219,6 +216,14 @@ class GenerationBatchResult:
    forward_batch: Optional[ForwardBatch] = None
    future_indices: Optional[FutureIndices] = None

    # FIXME(lsyin): maybe move to <BetterPlace> ?
    # sync path: forward stream -> output processor
    accept_lens: Optional[torch.Tensor] = None
    last_batch_allocate_lens: Optional[torch.Tensor] = None

    # relay path: forward stream -> next step forward
    next_draft_input: Optional[EagleDraftInput] = None

    def copy_to_cpu(self, return_logprob: bool = False):
        """Copy tensors to CPU in overlap scheduling.
        Only the tensors which are needed for processing results are copied,
@@ -238,6 +243,15 @@ class GenerationBatchResult:
                "cpu", non_blocking=True
            )
        self.next_token_ids = self.next_token_ids.to("cpu", non_blocking=True)

        if self.accept_lens is not None:
            self.accept_lens = self.accept_lens.to("cpu", non_blocking=True)

        if self.last_batch_allocate_lens is not None:
            self.last_batch_allocate_lens = self.last_batch_allocate_lens.to(
                "cpu", non_blocking=True
            )

        self.copy_done.record()

    @classmethod
@@ -273,48 +287,6 @@ class Scheduler(
):
    """A scheduler that manages a tensor parallel GPU worker."""

    def launch_draft_worker(
        self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
    ):
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker

            self.draft_worker = EAGLEWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        elif self.spec_algorithm.is_standalone():
            from sglang.srt.speculative.standalone_worker import StandaloneWorker

            self.draft_worker = StandaloneWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        elif self.spec_algorithm.is_ngram():
            from sglang.srt.speculative.ngram_worker import NGRAMWorker

            self.draft_worker = NGRAMWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

    def __init__(
        self,
        server_args: ServerArgs,
@@ -454,6 +426,7 @@ class Scheduler(
        )

        # Launch a draft worker for speculative decoding

        self.launch_draft_worker(
            gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
        )
@@ -683,6 +656,51 @@ class Scheduler(
            ]
        )

    def launch_draft_worker(
        self, gpu_id, tp_rank, moe_ep_rank, server_args, port_args, dp_rank
    ):
        if self.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_worker import EAGLEWorker
            from sglang.srt.speculative.eagle_worker_v2 import EAGLEWorkerV2

            WorkerClass = EAGLEWorkerV2 if self.enable_overlap else EAGLEWorker

            self.draft_worker = WorkerClass(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        elif self.spec_algorithm.is_standalone():
            from sglang.srt.speculative.standalone_worker import StandaloneWorker

            self.draft_worker = StandaloneWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        elif self.spec_algorithm.is_ngram():
            from sglang.srt.speculative.ngram_worker import NGRAMWorker

            self.draft_worker = NGRAMWorker(
                gpu_id=gpu_id,
                tp_rank=tp_rank,
                moe_ep_rank=moe_ep_rank,
                server_args=server_args,
                nccl_port=port_args.nccl_port,
                target_worker=self.tp_worker,
                dp_rank=dp_rank,
            )
        else:
            self.draft_worker = None

    def init_deterministic_inference_config(self):
        """Initialize deterministic inference configuration for different attention backends."""
        if not self.server_args.enable_deterministic_inference:
@@ -965,7 +983,9 @@ class Scheduler(
            self.device
        ).stream(self.copy_stream)

        self.future_map = FutureMap(self.max_running_requests, self.device)
        self.future_map = FutureMap(
            self.max_running_requests, self.device, self.spec_algorithm
        )
        self.batch_record_buf = [None] * 2
        self.batch_record_ct = 0
@@ -2096,7 +2116,7 @@ class Scheduler(

        batch_or_worker_batch = batch

        if self.spec_algorithm.is_none():
        if self.enable_overlap or self.spec_algorithm.is_none():
            # FIXME(lsyin): remove this if and finally unify the abstraction
            batch_or_worker_batch = batch.get_model_worker_batch()
@@ -2120,39 +2140,49 @@ class Scheduler(
            if batch.sampling_info.grammars is not None:
                model_worker_batch.delay_sample_launch = True
            batch_result = self.model_worker.forward_batch_generation(
                batch_or_worker_batch
                model_worker_batch
            )
            # FIXME(lsyin): maybe move this to forward_batch_generation
            batch_result.copy_done = torch.get_device_module(
                self.device
            ).Event()
            if not model_worker_batch.delay_sample_launch:
                self.future_map.store_to_map(
                    future_indices, batch_result.next_token_ids
                )
                self.future_map.store_to_map(future_indices, batch_result)
                batch_result.copy_to_cpu()
            else:
                batch_result.future_indices = future_indices

            # FIXME(lsyin): move this assignment elsewhere
            maybe_future_next_token_ids = -future_indices.indices
            future_indices_or_next_token_ids = -future_indices.indices

            if batch.is_v2_eagle:
                # FIXME(lsyin): tmp code for eagle v2
                # We only keep future indices for next draft input

                batch.spec_info = batch_result.next_draft_input
                batch.spec_info.future_indices = future_indices

                # batch.spec_info = EagleDraftInput(
                #     future_indices=future_indices,
                #     verify_done=batch_result.next_draft_input.verify_done,
                #     # FIXME(lsyin): remove the allocate_lens in EagleDraftInput
                #     allocate_lens=batch_result.next_draft_input.allocate_lens,
                # )

                # The future value, usually for next batch preparation
                # Current implementation strictly synchronizes the seq_lens
                batch.seq_lens = batch_result.next_draft_input.new_seq_lens
        else:
            batch_result = self.model_worker.forward_batch_generation(
                batch_or_worker_batch
            )
            maybe_future_next_token_ids = batch_result.next_token_ids
            future_indices_or_next_token_ids = batch_result.next_token_ids

            if not self.spec_algorithm.is_none():
                # TODO(lsyin): unify this metric-updating logic with non-spec, and move it to decode processing
                self.update_spec_metrics(
                    batch.batch_size(), batch_result.num_accepted_tokens
                )

        # NOTE: maybe_future_next_token_ids is used in ScheduleBatch,
        # NOTE: future_indices_or_next_token_ids is used in ScheduleBatch,
        # which can probably be replaced by future_indices later [TODO(lsyin)].
        # we shall still keep the original outputs, e.g. next_token_ids
        # in the GenerationBatchOutput for processing after copy_done.
        batch.output_ids = maybe_future_next_token_ids
        batch.output_ids = future_indices_or_next_token_ids

        # These 2 values are needed for processing the output, but the values can be
        # modified by overlap schedule. So we have to copy them here so that
@@ -2200,7 +2230,7 @@ class Scheduler(
                tmp_result.forward_batch,
            )
            future_indices = tmp_result.future_indices
            self.future_map.store_to_map(future_indices, tmp_result.next_token_ids)
            self.future_map.store_to_map(future_indices, tmp_result)
            tmp_result.copy_to_cpu()
            self.result_queue.appendleft((tmp_batch, tmp_result))
@@ -69,7 +69,7 @@ class SchedulerMetricsMixin:
            kv_events_config, self.attn_dp_rank
        )

    def update_spec_metrics(self, bs: int, num_accepted_tokens: int):
    def update_spec_metrics(self: Scheduler, bs: int, num_accepted_tokens: int):
        self.spec_num_total_accepted_tokens += num_accepted_tokens + bs
        self.spec_num_total_forward_ct += bs
        self.num_generated_tokens += num_accepted_tokens
@@ -1,7 +1,6 @@
from __future__ import annotations

import logging
import threading
import time
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
@@ -200,6 +199,28 @@ class SchedulerOutputProcessorMixin:

        self.stream_output(batch.reqs, batch.return_logprob, skip_stream_req)

    def hacky_process_eagle_overlap_result(
        self: Scheduler, result: GenerationBatchResult, batch: ScheduleBatch
    ):
        # TODO(lsyin): try use a copy stream to share SMs with forward
        # FIXME(lsyin): better organize this token free logic in eagle-overlap
        last_batch_allocate_lens_cpu = result.last_batch_allocate_lens.tolist()
        accept_lens_cpu = result.accept_lens.tolist()
        next_token_ids = result.next_token_ids.tolist()

        predict_tokens = []
        num_draft_tokens = self.draft_worker.speculative_num_draft_tokens
        for i, req in enumerate(batch.reqs):
            predict_tokens.append(
                next_token_ids[
                    i * num_draft_tokens : i * num_draft_tokens + accept_lens_cpu[i]
                ]
            )
            # FIXME(lsyin): move this update elsewhere
            req.spec_verify_ct += 1

        return last_batch_allocate_lens_cpu, accept_lens_cpu, predict_tokens
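For intuition, an illustrative example (not from the diff) of the slicing above: each request owns a fixed-size verify slot of num_draft_tokens entries and keeps only its accepted prefix.

# Illustrative only: two requests, 4-token verify slots
next_token_ids = [11, 12, 13, 14, 21, 22, 23, 24]
accept_lens_cpu = [2, 3]
# request 0 keeps next_token_ids[0:2] -> [11, 12]
# request 1 keeps next_token_ids[4:7] -> [21, 22, 23]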
    def process_batch_result_decode(
        self: Scheduler,
        batch: ScheduleBatch,
@@ -220,6 +241,17 @@ class SchedulerOutputProcessorMixin:
            next_token_ids = next_token_ids.tolist()
            if batch.return_logprob:
                next_token_logprobs = logits_output.next_token_logprobs.tolist()
        elif batch.is_v2_eagle:
            (
                last_batch_allocate_lens_cpu,
                accept_lens_cpu,
                next_token_ids,
            ) = self.hacky_process_eagle_overlap_result(result, batch)
            result.num_accepted_tokens = sum(accept_lens_cpu)

        # FIXME(lsyin): we suppose we have already got the num_accepted_tokens in result
        if not self.spec_algorithm.is_none():
            self.update_spec_metrics(batch.batch_size(), result.num_accepted_tokens)

        self.token_to_kv_pool_allocator.free_group_begin()
@@ -227,29 +259,74 @@ class SchedulerOutputProcessorMixin:
        # NOTE: the lengths of reqs and next_token_ids don't match if it is spec decoding.
        # We should ignore using next_token_ids for spec decoding cases.
        for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
            req: Req
            if req.is_retracted:
                continue

            if self.enable_overlap and req.finished():
                # Free the one extra delayed token
                if self.page_size == 1:
                    self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
                else:
                    # Only free when the extra token is in a new page
                    if (
                        len(req.origin_input_ids) + len(req.output_ids) - 1
                    ) % self.page_size == 0:
                    if batch.spec_algorithm.is_eagle():
                        from sglang.srt.speculative.eagle_worker_v2 import (
                            free_spec_dec_tokens_page_size_1,
                        )

                        free_spec_dec_tokens_page_size_1(
                            self.req_to_token_pool,
                            self.token_to_kv_pool_allocator,
                            req,
                            last_batch_allocate_lens_cpu[i],
                            None,
                        )
                    else:
                        # Free the one extra delayed token
                        self.token_to_kv_pool_allocator.free(
                            batch.out_cache_loc[i : i + 1]
                        )
                else:
                    if batch.spec_algorithm.is_eagle():
                        # TODO(lsyin): support eagle with page_size > 1
                        raise NotImplementedError()
                    else:
                        if (
                            len(req.origin_input_ids) + len(req.output_ids) - 1
                        ) % self.page_size == 0:
                            # Only free when the extra token is in a new page
                            self.token_to_kv_pool_allocator.free(
                                batch.out_cache_loc[i : i + 1]
                            )
                continue

            if batch.spec_algorithm.is_none():
                # speculative worker will solve the output_ids in speculative decoding
                req.output_ids.append(next_token_id)
            elif batch.is_v2_eagle:
                # FIXME(lsyin): non-overlap spec worker will solve the output_ids in speculative decoding
                # !!!unify the logic here!!!
                req.output_ids.extend(next_token_id)

            req.check_finished()
            if req.finished():
                if batch.is_v2_eagle and self.cur_batch.forward_mode.is_extend():
                    # FIXME(lsyin): fix the messy logic here
                    # 1) when not overlap (v2 impl), we free the extra tokens in the req
                    # 2) when overlap and current batch is extend, we free the extra tokens in the req of the previous batch
                    from sglang.srt.speculative.eagle_worker_v2 import (
                        free_spec_dec_tokens_page_size_1,
                    )

                    new_seq_len = len(req.origin_input_ids) + len(req.output_ids) - 1
                    # FIXME(lsyin): remove this assert
                    assert new_seq_len == int(
                        batch.seq_lens_cpu[i] + accept_lens_cpu[i]
                    ), f"{new_seq_len=} vs {batch.seq_lens_cpu[i] + accept_lens_cpu[i]=}"

                    free_spec_dec_tokens_page_size_1(
                        self.req_to_token_pool,
                        self.token_to_kv_pool_allocator,
                        req,
                        last_batch_allocate_lens_cpu[i],
                        new_seq_len,
                    )

                if self.server_args.disaggregation_decode_enable_offload_kvcache:
                    # Asynchronously offload KV cache; cache_finished_req will be called after Device->Host transfer completes
                    if not self.decode_offload_manager.offload_kv_cache(req):
@@ -231,12 +231,21 @@ class TpModelWorker:
    def forward_batch_generation(
        self,
        model_worker_batch: ModelWorkerBatch,
        forward_batch: Optional[ForwardBatch] = None,
        is_verify: bool = False,
        skip_attn_backend_init=False,
    ) -> GenerationBatchResult:
        # update the consumer index of hicache to the running batch
        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
        # FIXME(lsyin): maybe remove skip_attn_backend_init in forward_batch_generation,
        # which requires preparing replay to always be in this function

        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        if model_worker_batch is not None:
            # update the consumer index of hicache to the running batch
            self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)

            forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
        else:
            # FIXME(lsyin): unify the interface of forward_batch
            assert forward_batch is not None

        pp_proxy_tensors = None
        if not self.pp_group.is_first_rank:
@@ -248,7 +257,9 @@ class TpModelWorker:

        if self.pp_group.is_last_rank:
            logits_output, can_run_cuda_graph = self.model_runner.forward(
                forward_batch, pp_proxy_tensors=pp_proxy_tensors
                forward_batch,
                pp_proxy_tensors=pp_proxy_tensors,
                skip_attn_backend_init=skip_attn_backend_init,
            )
            batch_result = GenerationBatchResult(
                logits_output=logits_output,
@@ -290,6 +301,7 @@ class TpModelWorker:
            pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                forward_batch,
                pp_proxy_tensors=pp_proxy_tensors,
                skip_attn_backend_init=skip_attn_backend_init,
            )
            return GenerationBatchResult(
                pp_hidden_states_proxy_tensors=pp_proxy_tensors,
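A hedged sketch of the two call paths the extended signature is meant to support (simplified, not verbatim from the diff):

# Path 1: normal scheduling; the worker builds the ForwardBatch itself.
result = target_worker.forward_batch_generation(model_worker_batch)

# Path 2 (V2 verify): the ForwardBatch and attention plan were already prepared
# by prepare_for_v2_verify, so pass it in directly and skip backend re-init.
result = target_worker.forward_batch_generation(
    model_worker_batch=None,
    forward_batch=verify_forward_batch,
    is_verify=True,
    skip_attn_backend_init=True,
)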
@@ -678,8 +678,9 @@ class CudaGraphRunner:
        capture_hidden_mode_required_by_forward_batch = (
            forward_batch.capture_hidden_mode
        )
        capture_hidden_mode_required_by_spec_info = getattr(
            forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
        capture_hidden_mode_required_by_spec_info = (
            getattr(forward_batch.spec_info, "capture_hidden_mode", None)
            or CaptureHiddenMode.NULL
        )
        capture_hidden_mode_required_for_returning_hidden_states = (
            CaptureHiddenMode.FULL
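The rewritten expression matters when spec_info exists but its capture_hidden_mode attribute is None; a minimal illustration:

class _Spec:
    capture_hidden_mode = None

spec = _Spec()
getattr(spec, "capture_hidden_mode", CaptureHiddenMode.NULL)          # -> None
getattr(spec, "capture_hidden_mode", None) or CaptureHiddenMode.NULL  # -> NULL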
@@ -75,6 +75,8 @@ class ForwardMode(IntEnum):
    # Used in speculative decoding: extend a batch in the draft model.
    DRAFT_EXTEND = auto()

    DRAFT_EXTEND_V2 = auto()

    # Split Prefill for PD multiplexing
    SPLIT_PREFILL = auto()
@@ -107,6 +109,10 @@ class ForwardMode(IntEnum):
    def is_draft_extend(self):
        return self == ForwardMode.DRAFT_EXTEND

    def is_draft_extend_v2(self):
        # For fixed shape logits output in v2 eagle worker
        return self == ForwardMode.DRAFT_EXTEND_V2

    def is_extend_or_draft_extend_or_mixed(self):
        return (
            self == ForwardMode.EXTEND
@@ -312,6 +312,7 @@ class ServerArgs:
    nsa_decode: str = "fa3"

    # Speculative decoding
    enable_beta_spec: bool = False
    speculative_algorithm: Optional[str] = None
    speculative_draft_model_path: Optional[str] = None
    speculative_draft_model_revision: Optional[str] = None
@@ -1103,11 +1104,19 @@ class ServerArgs:
            )
            if self.max_running_requests is None:
                self.max_running_requests = 48
            self.disable_overlap_schedule = True
            logger.warning(
                "Overlap scheduler is disabled because of using "
                "eagle speculative decoding."
            )

            if self.speculative_algorithm == "EAGLE" and self.enable_beta_spec:
                self.disable_overlap_schedule = False
                logger.warning(
                    "Beta spec is enabled for eagle speculative decoding and overlap schedule is turned on."
                )

            if not self.enable_beta_spec:
                self.disable_overlap_schedule = True
                logger.warning(
                    "Overlap scheduler is disabled because of using eagle3 and standalone speculative decoding."
                )

            if self.enable_mixed_chunk:
                self.enable_mixed_chunk = False
                logger.warning(
@@ -2127,6 +2136,7 @@ class ServerArgs:
        )

        # Speculative decoding
        parser.add_argument("--enable-beta-spec", action="store_true")
        parser.add_argument(
            "--speculative-algorithm",
            type=str,
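A hedged illustration of how the new flag interacts with the existing options (the launch command is a plausible example, not taken from this diff):

# python -m sglang.launch_server --speculative-algorithm EAGLE \
#     --speculative-draft-model-path <draft_model> --enable-beta-spec
#
# With --enable-beta-spec, disable_overlap_schedule stays False for EAGLE, so the
# scheduler keeps the overlap path and launch_draft_worker selects EAGLEWorkerV2.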
@@ -1,7 +1,7 @@
import logging
from copy import copy
from dataclasses import dataclass
from typing import List, Optional, Tuple
from typing import ClassVar, List, Optional, Tuple

import torch
import torch.nn.functional as F
@@ -10,6 +10,7 @@ from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.overlap_utils import FutureIndices
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.mem_cache.common import (
@@ -18,16 +19,20 @@ from sglang.srt.mem_cache.common import (
    get_last_loc,
)
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.eagle_info_v2 import (
    EagleDraftInputV2Mixin,
    EagleVerifyInputV2Mixin,
)
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
from sglang.srt.speculative.spec_utils import (
    SIMULATE_ACC_LEN,
    TREE_SPEC_KERNEL_AVAILABLE,
    _generate_simulated_accept_index,
    align_evict_mask_to_page_size,
    assign_req_to_token_pool,
    create_accept_length_filter,
    create_extend_after_decode_spec_info,
    filter_finished_cache_loc_kernel,
    generate_simulated_accept_index,
    get_src_tgt_cache_loc,
    get_target_cache_loc,
)
@@ -47,7 +52,7 @@ logger = logging.getLogger(__name__)


@dataclass
class EagleVerifyInput(SpecInput):
class EagleVerifyInput(SpecInput, EagleVerifyInputV2Mixin):
    draft_token: torch.Tensor
    custom_mask: torch.Tensor
    positions: torch.Tensor
@@ -338,7 +343,7 @@ class EagleVerifyInput(SpecInput):

        if SIMULATE_ACC_LEN > 0.0:
            # Do simulation
            accept_index = _generate_simulated_accept_index(
            accept_index = generate_simulated_accept_index(
                accept_index=accept_index,
                predict=predict,  # mutable
                accept_length=accept_length,  # mutable
@@ -568,7 +573,7 @@ class EagleVerifyInput(SpecInput):


@dataclass
class EagleDraftInput(SpecInput):
class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
    # The inputs for decode
    # shape: (b, topk)
    topk_p: torch.Tensor = None
@@ -598,6 +603,15 @@ class EagleDraftInput(SpecInput):
    seq_lens_for_draft_extend_cpu: torch.Tensor = None
    req_pool_indices_for_draft_extend: torch.Tensor = None

    # Inputs for V2 overlap worker
    future_indices: Optional[FutureIndices] = None
    allocate_lens: Optional[torch.Tensor] = None
    new_seq_lens: Optional[torch.Tensor] = None
    verify_done: Optional[torch.cuda.Event] = None

    # FIXME(lsyin): remove this hack
    ALLOC_LEN_PER_DECODE: ClassVar[int] = None

    def __post_init__(self):
        super().__init__(SpecInputType.EAGLE_DRAFT)
@@ -703,6 +717,11 @@ class EagleDraftInput(SpecInput):
        return kv_indices, cum_kv_seq_len, qo_indptr, None

    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
        if self.future_indices is not None:
            self.future_indices.indices = self.future_indices.indices[new_indices]
            self.allocate_lens = self.allocate_lens[new_indices]
            return

        if has_been_filtered:
            # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
            # therefore, we don't need to filter the batch again in scheduler
@@ -722,6 +741,18 @@ class EagleDraftInput(SpecInput):
        self.verified_id = self.verified_id[new_indices]

    def merge_batch(self, spec_info: "EagleDraftInput"):
        if self.future_indices is not None:
            assert spec_info.future_indices is not None
            self.future_indices = FutureIndices(
                indices=torch.cat(
                    [self.future_indices.indices, spec_info.future_indices.indices]
                )
            )
            self.allocate_lens = torch.cat(
                [self.allocate_lens, spec_info.allocate_lens]
            )
            return

        if self.hidden_states is None:
            self.hidden_states = spec_info.hidden_states
            self.verified_id = spec_info.verified_id
python/sglang/srt/speculative/eagle_info_v2.py (new file, 514 lines)
@@ -0,0 +1,514 @@
from __future__ import annotations

import math
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, List, Optional

import torch
import torch.nn.functional as F
import triton
import triton.language as tl

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
from sglang.srt.managers.scheduler import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
)
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.build_eagle_tree import TreeMaskMode
from sglang.srt.speculative.spec_utils import (
    SIMULATE_ACC_LEN,
    generate_simulated_accept_index,
)
from sglang.srt.utils.common import fast_topk, is_cuda, is_hip, next_power_of_2

if TYPE_CHECKING:
    from sglang.srt.managers.tp_worker import TpModelWorker
    from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
        EAGLEDraftCudaGraphRunner,
    )
    from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput

if is_cuda():
    from sgl_kernel import (
        top_k_renorm_prob,
        top_p_renorm_prob,
        tree_speculative_sampling_target_only,
        verify_tree_greedy,
    )
    from sgl_kernel.top_k import fast_topk
elif is_hip():
    from sgl_kernel import verify_tree_greedy


@triton.jit
def assign_draft_cache_locs_page_size_1(
    req_pool_indices,
    req_to_token,
    seq_lens,
    out_cache_loc,
    pool_len: tl.constexpr,
    topk: tl.constexpr,
    speculative_num_steps: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 128
    pid = tl.program_id(axis=0)

    copy_len = topk * speculative_num_steps
    out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps

    # Copy from req_to_token to out_cache_loc
    kv_start = tl.load(seq_lens + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
    for i in range(num_loop):
        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = copy_offset < copy_len
        data = tl.load(token_pool + kv_start + copy_offset, mask=mask)
        tl.store(out_cache_ptr + copy_offset, data, mask=mask)


@dataclass
class EagleDraftInputV2Mixin:
    def prepare_for_v2_draft(
        self: EagleDraftInput,
        req_to_token_pool: ReqToTokenPool,
        batch: ModelWorkerBatch,
        cuda_graph_runner: EAGLEDraftCudaGraphRunner,
        draft_model_runner: ModelRunner,
        topk: int,
        num_steps: int,
    ):
        bs = len(batch.seq_lens)

        # Assign cache locations
        batch.out_cache_loc = torch.empty(
            (bs * topk * num_steps,),
            dtype=torch.int64,
            device=batch.input_ids.device,
        )
        # FIXME(lsyin): align with the default code path
        assign_draft_cache_locs_page_size_1[(bs,)](
            batch.req_pool_indices,
            req_to_token_pool.req_to_token,
            batch.seq_lens,
            batch.out_cache_loc,
            req_to_token_pool.req_to_token.shape[1],
            topk,
            num_steps,
        )

        # Get a forward batch
        batch.capture_hidden_mode = CaptureHiddenMode.LAST
        self.positions = batch.seq_lens.repeat_interleave(topk, dim=0)
        forward_batch = ForwardBatch.init_new(batch, draft_model_runner)
        can_cuda_graph = cuda_graph_runner and cuda_graph_runner.can_run(forward_batch)
        return forward_batch, can_cuda_graph

    def prepare_for_extend_to_fill_draft_kvcache(
        self,
        batch: ModelWorkerBatch,
        predict: torch.Tensor,
        num_draft_tokens: int,
        draft_model_runner: Any,
    ):
        seq_lens_cpu_backup = batch.seq_lens_cpu
        extend_num_tokens = len(batch.seq_lens) * num_draft_tokens

        batch.spec_info = self
        batch.input_ids = predict
        batch.seq_lens = batch.seq_lens + num_draft_tokens
        batch.seq_lens_cpu = batch.seq_lens_cpu + num_draft_tokens
        batch.seq_lens_sum += extend_num_tokens
        batch.extend_seq_lens = [num_draft_tokens for _ in range(len(batch.seq_lens))]
        batch.extend_prefix_lens = seq_lens_cpu_backup.tolist()
        batch.extend_prefix_lens_cpu = seq_lens_cpu_backup
        batch.extend_num_tokens = extend_num_tokens
        batch.capture_hidden_mode = CaptureHiddenMode.FULL
        batch.forward_mode = ForwardMode.DRAFT_EXTEND_V2
        forward_batch = ForwardBatch.init_new(batch, draft_model_runner)
        draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
        return forward_batch


@dataclass
class EagleVerifyInputV2Mixin:
    def prepare_for_v2_verify(
        self: EagleVerifyInput,
        req_to_token_pool: ReqToTokenPool,
        batch: ModelWorkerBatch,
        target_worker: TpModelWorker,
    ):
        # Assign cache locations
        bs = len(batch.req_pool_indices)
        batch.input_ids = self.draft_token
        device = batch.input_ids.device
        batch.out_cache_loc = torch.empty(
            (bs * self.draft_token_num,),
            dtype=torch.int64,
            device=device,
        )

        assign_extend_cache_locs[(bs,)](
            batch.req_pool_indices,
            req_to_token_pool.req_to_token,
            batch.seq_lens,
            batch.seq_lens + self.draft_token_num,
            batch.out_cache_loc,
            req_to_token_pool.req_to_token.shape[1],
            next_power_of_2(bs),
        )

        # Get a forward batch
        batch.forward_mode = ForwardMode.TARGET_VERIFY
        batch.capture_hidden_mode = CaptureHiddenMode.FULL
        verify_forward_batch = ForwardBatch.init_new(batch, target_worker.model_runner)

        # Run attention backend plan and cuda graph preparation
        can_run_cuda_graph = bool(
            target_worker.model_runner.graph_runner
            and target_worker.model_runner.graph_runner.can_run(verify_forward_batch)
        )
        if can_run_cuda_graph:
            target_worker.model_runner.graph_runner.replay_prepare(verify_forward_batch)
        else:
            target_worker.model_runner.attn_backend.init_forward_metadata(
                verify_forward_batch
            )

        return verify_forward_batch, can_run_cuda_graph

    def sample(
        self: EagleVerifyInput,
        batch: ModelWorkerBatch,
        logits_output: LogitsProcessorOutput,
    ):
        """
        Verify and find accepted tokens based on logits output and batch
        (which contains spec decoding information).
        """
        bs = len(batch.seq_lens)
        sampling_info = batch.sampling_info
        next_token_logits = logits_output.next_token_logits
        device = batch.input_ids.device

        candidates = self.draft_token.reshape(bs, self.draft_token_num)
        predict = torch.zeros(
            (bs * (self.spec_steps + 1),), dtype=torch.int32, device=device
        )
        accept_index = torch.full(
            (bs, self.spec_steps + 1), -1, dtype=torch.int32, device=device
        )
        accept_length = torch.empty((bs,), dtype=torch.int32, device=device)

        # Sample tokens
        if sampling_info.is_all_greedy:
            target_predict = torch.argmax(next_token_logits, dim=-1)
            target_predict = target_predict.reshape(bs, self.draft_token_num)

            verify_tree_greedy(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable
                candidates=candidates,
                retrive_index=self.retrive_index,
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                target_predict=target_predict,
            )
        else:
            # Apply temperature and get target probs
            expanded_temperature = torch.repeat_interleave(
                sampling_info.temperatures, self.draft_token_num, dim=0
            )  # (bs * num_draft_tokens, 1)

            target_probs = F.softmax(
                next_token_logits / expanded_temperature, dim=-1
            )  # (bs * num_draft_tokens, vocab_size)
            target_probs = top_k_renorm_prob(
                target_probs,
                torch.repeat_interleave(
                    sampling_info.top_ks, self.draft_token_num, dim=0
                ),
            )  # (bs * num_draft_tokens, vocab_size)
            target_probs = top_p_renorm_prob(
                target_probs,
                torch.repeat_interleave(
                    sampling_info.top_ps, self.draft_token_num, dim=0
                ),
            )
            target_probs = target_probs.reshape(bs, self.draft_token_num, -1)

            # This is currently not used
            draft_probs = torch.empty_like(target_probs)

            # coins for rejection sampling
            coins = torch.rand_like(candidates, dtype=torch.float32, device=device)
            # coins for final sampling
            coins_for_final_sampling = torch.rand(
                (bs,), dtype=torch.float32, device=device
            )

            tree_speculative_sampling_target_only(
                predicts=predict,  # mutable
                accept_index=accept_index,  # mutable
                accept_token_num=accept_length,  # mutable
                candidates=candidates,
                retrive_index=self.retrive_index,
                retrive_next_token=self.retrive_next_token,
                retrive_next_sibling=self.retrive_next_sibling,
                uniform_samples=coins,
                uniform_samples_for_final_sampling=coins_for_final_sampling,
                target_probs=target_probs,
                draft_probs=draft_probs,
                threshold_single=global_server_args_dict[
                    "speculative_accept_threshold_single"
                ],
                threshold_acc=global_server_args_dict[
                    "speculative_accept_threshold_acc"
                ],
                deterministic=True,
            )

        if SIMULATE_ACC_LEN > 0:
            # Do simulation
            accept_index = generate_simulated_accept_index(
                accept_index=accept_index,
                predict=predict,  # mutable
                accept_length=accept_length,  # mutable
                simulate_acc_len=SIMULATE_ACC_LEN,
                bs=bs,
                spec_steps=self.draft_token_num,
            )

        # Include the bonus token
        accept_length.add_(1)
        return predict, accept_length, accept_index


def build_tree_kernel_efficient_tmp(
    verified_id: torch.Tensor,
    parent_list: List[torch.Tensor],
    top_scores_index: torch.Tensor,
    draft_tokens: torch.Tensor,
    seq_lens: torch.Tensor,
    seq_lens_sum: int,
    topk: int,
    spec_steps: int,
    num_verify_tokens: int,
    tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
    tree_mask_buf: Optional[torch.Tensor] = None,
    position_buf: Optional[torch.Tensor] = None,
):
    # TODO(lsyin): make it compatible with default code path
    # TODO(lsyin): support cuda graph padding for eagle
    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()

    # seq_lens_sum == sum(seq_lens); seq_lens: sequence length without draft tokens
    bs = seq_lens.numel()
    device = seq_lens.device
    # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
    # where each row indicates the attending pattern of each draft token
    # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
    if tree_mask_buf is not None:
        tree_mask = tree_mask_buf
        if tree_mask_mode == TreeMaskMode.QLEN_ONLY:
            tree_mask.fill_(True)
        elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
            tree_mask.fill_(0)
        elif tree_mask_mode == TreeMaskMode.FULL_MASK:
            tree_mask.fill_(True)
        else:
            raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
        tree_mask = torch.full(
            (num_verify_tokens * bs * num_verify_tokens,),
            True,
            dtype=torch.bool,
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
        packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
        packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
        tree_mask = torch.zeros(
            (num_verify_tokens * bs,),
            dtype=packed_dtypes[packed_dtype_idx],
            device=device,
        )
    elif tree_mask_mode == TreeMaskMode.FULL_MASK:
        tree_mask = torch.full(
            (
                seq_lens_sum * num_verify_tokens
                + num_verify_tokens * num_verify_tokens * bs,
            ),
            True,
            device=device,
        )
    else:
        raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")

    # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
    retrive_buf = torch.full(
        (3, bs, num_verify_tokens), -1, device=device, dtype=torch.long
    )
    retrive_index, retrive_next_token, retrive_next_sibling = retrive_buf
    # position: where each token belongs to
    # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
    # then, positions = [7, 8, 8, 9]
    if position_buf is not None:
        positions = position_buf
    else:
        positions = torch.empty(
            (bs * num_verify_tokens,), device=device, dtype=torch.long
        )

    from sgl_kernel import (
        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
    )

    sgl_build_tree_kernel_efficient(
        parent_list,
        top_scores_index,
        seq_lens,
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        topk,
        spec_steps,
        num_verify_tokens,
        tree_mask_mode,
    )
    return (
        tree_mask,
        positions,
        retrive_index,
        retrive_next_token,
        retrive_next_sibling,
        draft_tokens,
    )


@torch.compile(dynamic=True)
def select_top_k_tokens_tmp(
    i: int,
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: torch.Tensor,
    scores: torch.Tensor,
    topk: int,
):
    # FIXME(lsyin): remove this duplicate code
    if i == 0:
        # The first step after extend
        input_ids = topk_index.flatten()
        hidden_states = hidden_states.repeat_interleave(topk, dim=0)
        scores = topk_p  # shape: (b, topk)

        tree_info = (
            topk_p.unsqueeze(1),  # shape: (b, 1, topk)
            topk_index,  # shape: (b, topk)
            torch.arange(-1, topk, dtype=torch.long, device=hidden_states.device)
            .unsqueeze(0)
            .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
        )
    else:
        # The later decode steps
        expand_scores = torch.mul(
            scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
        )  # (b, topk, 1) x (b, topk, topk) -> (b, topk, topk)
        topk_cs_p, topk_cs_index = fast_topk(
            expand_scores.flatten(start_dim=1), topk, dim=-1
        )  # (b, topk)
        scores = topk_cs_p  # shape: (b, topk)

        topk_index = topk_index.reshape(-1, topk**2)
        input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()

        selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
            0, hidden_states.shape[0], step=topk, device=hidden_states.device
        ).repeat_interleave(topk)
        hidden_states = hidden_states[selected_input_index, :]

        tree_info = (
            expand_scores,  # shape: (b, topk, topk)
            topk_index,  # shape: (b, topk * topk)
            topk_cs_index + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
        )

    return input_ids, hidden_states, scores, tree_info


@triton.jit
def fill_new_verified_id(
    verified_id,
    accept_lens,
    new_verified_id,
    num_draft_tokens: tl.constexpr,
):
    # NOTE: we cannot fuse any in-place operations of `accept_lens` inside this kernel
    # because this kernel reads accept_lens
    pid = tl.program_id(axis=0)
    accept_length = tl.load(accept_lens + pid)

    verified_id_idx = num_draft_tokens * pid + accept_length - 1
    verified_id_data = tl.load(verified_id + verified_id_idx)
    tl.store(new_verified_id + pid, verified_id_data)


@triton.jit
def fill_accepted_out_cache_loc(
    accept_index,
    out_cache_loc,
    accepted_out_cache_loc,
    size_upper: tl.constexpr,
):
    pid = tl.program_id(axis=0)
    offset = tl.arange(0, size_upper)

    masks = (tl.load(accept_index + offset, offset < pid, other=-1) != -1).to(tl.int64)
    dst = tl.sum(masks)
    src = tl.load(accept_index + pid)
    if src > -1:
        value = tl.load(out_cache_loc + src)
        tl.store(accepted_out_cache_loc + dst, value)


@triton.jit
def assign_extend_cache_locs(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 32
    pid = tl.program_id(axis=0)
    kv_start = tl.load(start_offset + pid)
    kv_end = tl.load(end_offset + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len

    length_offset = tl.arange(0, bs_upper)
    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
    out_offset = tl.sum(end - start, axis=0)

    out_cache_ptr = out_cache_loc + out_offset

    load_offset = tl.arange(0, BLOCK_SIZE) + kv_start
    save_offset = tl.arange(0, BLOCK_SIZE)

    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = load_offset < kv_end
        data = tl.load(token_pool + load_offset, mask=mask)
        tl.store(out_cache_ptr + save_offset, data, mask=mask)
        load_offset += BLOCK_SIZE
        save_offset += BLOCK_SIZE
python/sglang/srt/speculative/eagle_worker_v2.py (new file, 482 lines)
@@ -0,0 +1,482 @@
import logging
from typing import List, Optional

import torch
from torch.cuda import Stream as CudaStream

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, Req
from sglang.srt.managers.scheduler import GenerationBatchResult
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardBatch
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.build_eagle_tree import TreeMaskMode
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.eagle_info_v2 import (
    assign_extend_cache_locs,
    build_tree_kernel_efficient_tmp,
    fill_accepted_out_cache_loc,
    fill_new_verified_id,
    select_top_k_tokens_tmp,
)
from sglang.srt.speculative.eagle_worker import EAGLEWorker
from sglang.srt.utils.common import fast_topk, next_power_of_2

logger = logging.getLogger(__name__)


class EAGLEWorkerV2(EAGLEWorker):
    def __init__(
        self,
        server_args: ServerArgs,
        gpu_id: int,
        tp_rank: int,
        dp_rank: Optional[int],
        moe_ep_rank: int,
        nccl_port: int,
        target_worker: TpModelWorker,
    ):
        super().__init__(
            server_args,
            gpu_id,
            tp_rank,
            dp_rank,
            moe_ep_rank,
            nccl_port,
            target_worker,
        )
        EagleDraftInput.ALLOC_LEN_PER_DECODE = max(
            self.speculative_num_steps * self.topk, self.speculative_num_draft_tokens
        )
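        # Upper bound on the KV slots a request may consume per decode iteration:
        # either the draft tree (num_steps * topk) or the verify tokens, whichever is larger.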
        self.tree_mask_mode = TreeMaskMode.FULL_MASK
        self.plan_stream: CudaStream = torch.get_device_module(self.device).Stream()
        # TODO(lsyin): potential bugs with a separate plan stream
        self.plan_stream_ctx = torch.cuda.stream(self.plan_stream)
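        # The plan stream lets attention-metadata preparation (planning) overlap with
        # computation on the main stream; see how it is used in verify() below.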

    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
        if model_worker_batch.forward_mode.is_decode():
            # FIXME(lsyin): why shall we use spec_info for both draft and verify?
            draft_input: EagleDraftInput = model_worker_batch.spec_info
            assert draft_input.is_draft_input()
            verify_input: EagleVerifyInput = self.draft(model_worker_batch)
            assert verify_input.is_verify_input()
            model_worker_batch.spec_info = verify_input
            batch_output = self.verify(model_worker_batch, draft_input.allocate_lens)
            return batch_output
        else:
            # Target prefill
            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
            batch_output = self.target_worker.forward_batch_generation(
                model_worker_batch
            )

            # Draft prefill
            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.LAST
            batch_output.next_draft_input = self.forward_draft_extend(
                model_worker_batch,
                batch_output.logits_output.hidden_states,
                batch_output.next_token_ids,
            )
            return batch_output

    def draft(self, model_worker_batch: ModelWorkerBatch):
        draft_input: EagleDraftInput = model_worker_batch.spec_info
        forward_batch, can_cuda_graph = draft_input.prepare_for_v2_draft(
            self.req_to_token_pool,
            model_worker_batch,
            self.cuda_graph_runner,
            self.draft_model_runner,
            self.topk,
            self.speculative_num_steps,
        )

        # Run draft
        if can_cuda_graph:
            parent_list, top_scores_index, draft_tokens = self.cuda_graph_runner.replay(
                forward_batch,
            )
        else:
            self.draft_attn_backend.init_forward_metadata(forward_batch)
            parent_list, top_scores_index, draft_tokens = self.draft_forward(
                forward_batch
            )

        # Build the tree mask.
        # Write directly into the cuda graph buffers of the verify attention backend.
        tree_mask_buf, position_buf = (
            self.target_worker.model_runner.attn_backend.get_verify_buffers_to_fill_after_draft()
        )

        (
            tree_mask,
            position,
            retrive_index,
            retrive_next_token,
            retrive_next_sibling,
            draft_tokens,
        ) = build_tree_kernel_efficient_tmp(
            draft_input.verified_id,
            parent_list,
            top_scores_index,
            draft_tokens,
            model_worker_batch.seq_lens,
            model_worker_batch.seq_lens_sum,
            self.topk,
            self.speculative_num_steps,
            self.speculative_num_draft_tokens,
            self.tree_mask_mode,
            tree_mask_buf,
            position_buf,
        )

        return EagleVerifyInput(
            draft_token=draft_tokens,
            custom_mask=tree_mask,
            positions=position,
            retrive_index=retrive_index,
            retrive_next_token=retrive_next_token,
            retrive_next_sibling=retrive_next_sibling,
            retrive_cum_len=None,
            spec_steps=self.speculative_num_steps,
            topk=self.topk,
            draft_token_num=self.speculative_num_draft_tokens,
            capture_hidden_mode=None,
            seq_lens_sum=None,
            seq_lens_cpu=None,
        )

    def draft_forward(self, forward_batch: ForwardBatch):
        # Parse args
        spec_info: EagleDraftInput = forward_batch.spec_info
        out_cache_loc = forward_batch.out_cache_loc
        topk_p, topk_index, hidden_states = (
            spec_info.topk_p,
            spec_info.topk_index,
            spec_info.hidden_states,
        )
        if self.hot_token_id is not None:
            topk_index = self.hot_token_id[topk_index]

        out_cache_loc = out_cache_loc.reshape(
            forward_batch.batch_size, self.topk, self.speculative_num_steps
        )
        out_cache_loc = out_cache_loc.permute((2, 0, 1)).reshape(
            self.speculative_num_steps, -1
        )
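        # out_cache_loc is reshaped from (bs, topk, num_steps) to (num_steps, bs * topk)
        # so that step i can directly take out_cache_loc[i] as its cache slots.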

        # Return values
        score_list: List[torch.Tensor] = []
        token_list: List[torch.Tensor] = []
        parents_list: List[torch.Tensor] = []

        # Forward multiple steps
        scores = None
        for i in range(self.speculative_num_steps):
            input_ids, hidden_states, scores, tree_info = select_top_k_tokens_tmp(
                i, topk_p, topk_index, hidden_states, scores, self.topk
            )
            score_list.append(tree_info[0])
            token_list.append(tree_info[1])
            parents_list.append(tree_info[2])

            # We don't need to run the last forward: we get 1 token from the draft
            # prefill and (#spec steps - 1) tokens here.
            if i == self.speculative_num_steps - 1:
                break

            # Set inputs
            forward_batch.input_ids = input_ids
            forward_batch.out_cache_loc = out_cache_loc[i]
            forward_batch.positions.add_(1)
            forward_batch.attn_backend = self.draft_attn_backend.attn_backends[i]
            spec_info.hidden_states = hidden_states

            # Run forward
            logits_output = self.draft_model_runner.model.forward(
                forward_batch.input_ids, forward_batch.positions, forward_batch
            )
            self._detect_nan_if_needed(logits_output)
            probs = torch.softmax(logits_output.next_token_logits, dim=-1)
            topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
            if self.hot_token_id is not None:
                topk_index = self.hot_token_id[topk_index]
            hidden_states = logits_output.hidden_states

        # Organize the results
        score_list = torch.cat(score_list, dim=1).flatten(
            1
        )  # (b, n, topk) before flatten; n = 1 + (num_steps - 1) * topk
        ss_token_list = torch.cat(
            token_list, dim=1
        )  # (b, topk + (num_steps - 1) * topk**2)
        top_scores = torch.topk(
            score_list, self.speculative_num_draft_tokens - 1, dim=-1
        )
        top_scores_index = top_scores.indices
        top_scores_index = torch.sort(top_scores_index).values
        draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
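        # Keep the (num_draft_tokens - 1) highest-scoring candidates; together with
        # the verified root token they form the draft tree sent to verification.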

        if len(parents_list) > 1:
            parent_list = torch.cat(parents_list[:-1], dim=1)
        else:
            batch_size = parents_list[0].shape[0]
            parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)

        return parent_list, top_scores_index, draft_tokens

    def verify(
        self,
        batch: ModelWorkerBatch,
        pre_draft_allocate_lens: torch.Tensor,
    ):
        # Parse args
        verify_input: EagleVerifyInput = batch.spec_info
        seq_lens_backup = batch.seq_lens
        bs = len(batch.seq_lens)

        # Batch 1: Target verify
        # Prepare for target verify in a separate stream
        with self.plan_stream_ctx:
            verify_forward_batch, can_run_cuda_graph = (
                verify_input.prepare_for_v2_verify(
                    self.req_to_token_pool,
                    batch,
                    self.target_worker,
                )
            )

        # Correct some buffers due to the overlap plan
        if self.plan_stream:
            torch.cuda.current_stream().wait_stream(self.plan_stream)

        # Some values such as custom_mask and position depend on the output of the draft,
        # so the earlier plan step used stale values. Re-run the dependent computation
        # here to update them to the correct values.
        self.target_worker.model_runner.attn_backend.update_verify_buffers_to_fill_after_draft(
            verify_input,
            (
                self.target_worker.model_runner.graph_runner.bs
                if can_run_cuda_graph
                else None
            ),
        )

        # Run target verify batch in the main compute stream
        forward_batch_output = self.target_worker.forward_batch_generation(
            model_worker_batch=None,
            forward_batch=verify_forward_batch,
            is_verify=True,
            skip_attn_backend_init=True,
        )
        logits_output = forward_batch_output.logits_output

        # Sample
        self._detect_nan_if_needed(logits_output)
        (
            predict,
            accept_length,
            accept_index,
        ) = verify_input.sample(batch, logits_output)
        new_seq_lens = seq_lens_backup + accept_length
        verify_done = torch.cuda.Event()

        # Move the accepted tokens to the target KV cache locations
        batch.seq_lens = seq_lens_backup
        self.move_accepted_tokens_to_target_kvcache(
            batch,
            accept_index,
            accept_length,
        )

        verify_done.record()

        all_verified_id = predict[accept_index]
        verified_id = torch.empty_like(accept_length, dtype=torch.int32)
        fill_new_verified_id[(bs,)](
            all_verified_id,
            accept_length,
            verified_id,
            self.speculative_num_draft_tokens,
        )
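        # verified_id[i] is the last accepted token of request i; it becomes the root
        # of the next draft tree.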

        # Batch 2: Draft extend
        draft_input = EagleDraftInput(
            hidden_states=logits_output.hidden_states,
        )
        select_index = (
            torch.arange(len(batch.seq_lens), device=self.device)
            * self.speculative_num_draft_tokens
            + accept_length
            - 1
        )
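        # For each request, select the logits/hidden state of its last accepted token
        # from the flattened (bs * num_draft_tokens) verify output.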

        # Prepare for draft extend in a separate stream
        with self.plan_stream_ctx:
            forward_batch = draft_input.prepare_for_extend_to_fill_draft_kvcache(
                batch,
                predict,
                self.speculative_num_draft_tokens,
                self.draft_model_runner,
            )

        if self.plan_stream:
            torch.cuda.current_stream().wait_stream(self.plan_stream)

        # Run draft extend batch in the main compute stream
        draft_logits_output = self.draft_model_runner.model.forward(
            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

        # Reorganize the spec info for the next batch
        draft_logits_output.next_token_logits = draft_logits_output.next_token_logits[
            select_index
        ]
        draft_logits_output.hidden_states = draft_logits_output.hidden_states[
            select_index
        ]
        probs = torch.softmax(draft_logits_output.next_token_logits, dim=-1)
        ret_topk_p, ret_topk_index = fast_topk(probs, self.topk, dim=-1)
        ret_hidden_states = draft_logits_output.hidden_states

        # Since seq_lens_backup's tensor is allocated in another stream, we need
        # record_stream() to prevent PyTorch from freeing and reusing the GPU memory
        # while the forward stream is still running.
        seq_lens_backup.record_stream(torch.cuda.current_stream())

        # Construct the return values
        next_draft_input = EagleDraftInput(
            topk_p=ret_topk_p,
            topk_index=ret_topk_index,
            hidden_states=ret_hidden_states,
            verified_id=verified_id,
            new_seq_lens=new_seq_lens,
            allocate_lens=pre_draft_allocate_lens,
            verify_done=verify_done,
        )

        return GenerationBatchResult(
            logits_output=logits_output,
            next_token_ids=predict,
            can_run_cuda_graph=can_run_cuda_graph,
            next_draft_input=next_draft_input,
            accept_lens=accept_length,
            last_batch_allocate_lens=pre_draft_allocate_lens,
        )

    def forward_draft_extend(
        self,
        batch: ModelWorkerBatch,
        target_hidden_states: torch.Tensor,
        next_token_ids: torch.Tensor,
    ):
        """
        Run the draft model extend to correctly fill the draft model's KV cache.

        Args:
            batch: The batch to run.
            target_hidden_states: Hidden states from the target model forward.
            next_token_ids: Next token ids generated by the target forward.
        """
        # Construct input_ids
        pt = 0
        for i, extend_len in enumerate(batch.extend_seq_lens):
            input_ids = batch.input_ids[pt : pt + extend_len]
            batch.input_ids[pt : pt + extend_len] = torch.cat(
                (input_ids[1:], next_token_ids[i].reshape(1))
            )
            pt += extend_len
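        # Each request's extend segment is now shifted left by one token and ends with
        # the token sampled by the target model, matching the draft model's input layout.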

        # Construct spec_info
        next_draft_input = EagleDraftInput(
            hidden_states=target_hidden_states,
            verified_id=next_token_ids,
            new_seq_lens=batch.seq_lens,
            allocate_lens=batch.seq_lens,
        )
        batch.spec_info = next_draft_input

        # Run forward
        forward_batch = ForwardBatch.init_new(batch, self.draft_model_runner)
        logits_output, _ = self.draft_model_runner.forward(forward_batch)

        # Update spec_info for the next draft step
        probs = torch.softmax(logits_output.next_token_logits, dim=-1)
        next_draft_input.topk_p, next_draft_input.topk_index = fast_topk(
            probs, self.topk, dim=-1
        )
        next_draft_input.hidden_states = logits_output.hidden_states
        return next_draft_input

    def move_accepted_tokens_to_target_kvcache(
        self,
        batch: ModelWorkerBatch,
        accept_index: torch.Tensor,
        accept_length: torch.Tensor,
    ):
        """
        Move accepted tokens to the target KV cache.

        Args:
            batch: The batch to run.
            accept_index: The indices of the accepted tokens.
            accept_length: The number of accepted tokens per request.
        """
        bs = len(batch.seq_lens)
        size = bs * self.speculative_num_draft_tokens

        tgt_cache_loc = torch.zeros(
            size,
            dtype=torch.int64,
            device=self.device,
        )
        accepted_out_cache_loc = torch.zeros(
            size, dtype=torch.int64, device=self.device
        )
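        # tgt_cache_loc: destination slots inside each request's own KV region
        # (seq_len .. seq_len + accept_len); accepted_out_cache_loc: the slots where
        # the verify pass wrote the accepted tokens (the copy source).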
        assign_extend_cache_locs[(bs,)](
            batch.req_pool_indices,
            self.req_to_token_pool.req_to_token,
            batch.seq_lens,
            batch.seq_lens + accept_length,
            tgt_cache_loc,
            self.req_to_token_pool.req_to_token.shape[1],
            next_power_of_2(bs),
        )
        fill_accepted_out_cache_loc[(size,)](
            accept_index,
            batch.out_cache_loc,
            accepted_out_cache_loc,
            next_power_of_2(size),
        )
        self.token_to_kv_pool_allocator.get_kvcache().move_kv_cache(
            tgt_cache_loc, accepted_out_cache_loc
        )

    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
        if self.enable_nan_detection:
            logits = logits_output.next_token_logits
            if torch.any(torch.isnan(logits)):
                logger.error("Detected errors during sampling! NaN in the logits.")
                raise ValueError("Detected errors during sampling! NaN in the logits.")


def free_spec_dec_tokens_page_size_1(
    req_to_token_pool: ReqToTokenPool,
    token_to_kv_pool_allocator: TokenToKVPoolAllocator,
    req: Req,
    allocate_len: int,
    new_seq_len: int,
):
    # FIXME(lsyin): move this function elsewhere

    # Free the extra allocated tokens.
    if new_seq_len is None:
        # True only for overlap EAGLE when the current batch is decode: this seq will be
        # part of the decode, so the final iteration's allocation is not used (i.e., this case).
        start_len = allocate_len - EagleDraftInput.ALLOC_LEN_PER_DECODE
    else:
        # True for 1) non-overlap; 2) overlap EAGLE when the current batch is prefill:
        # this seq will not run an extra iteration, so the start length (`new_seq_len`) is passed in.
        start_len = new_seq_len
    indices_to_free = req_to_token_pool.req_to_token[req.req_pool_idx][
        start_len:allocate_len
    ]
    token_to_kv_pool_allocator.free(indices_to_free)
@@ -435,7 +435,7 @@ def select_top_k_tokens(
    return input_ids, hidden_states, scores, tree_info


def _generate_simulated_accept_index(
def generate_simulated_accept_index(
    accept_index,
    predict,
    accept_length,
@@ -4,7 +4,7 @@ import copy
import dataclasses
import logging
from dataclasses import replace
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence

import torch
@@ -30,12 +30,12 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip

if TYPE_CHECKING:
    from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
    from sglang.srt.speculative.eagle_info import EagleVerifyInput

_is_hip = is_hip()