Organize spec-related data structures (#10735)

This commit is contained in:
Liangsheng Yin
2025-10-01 09:45:30 +08:00
committed by GitHub
parent 7fb551a75d
commit 73d4a5f879
32 changed files with 959 additions and 923 deletions

View File

@@ -37,7 +37,7 @@ except ImportError:
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
# Env var was set in sglang.srt.server_args.ServerArgs.__post__init__
# Env var was set in sglang.srt.server_args.ServerArgs.__post_init__
DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
logger = logging.getLogger(__name__)

View File

@@ -157,7 +157,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
# local import to avoid circular import
from sglang.srt.speculative.eagle_utils import EagleDraftInput
from sglang.srt.speculative.eagle_info import EagleDraftInput
spec_info = EagleDraftInput(
topk_p=topk_p,

View File

@@ -4,18 +4,13 @@ from __future__ import annotations
end to end attention solution with aiter kernels
"""
import math
import os
from dataclasses import dataclass
from enum import Enum, auto
from functools import partial
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Optional
import torch
import triton
import triton.language as tl
from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import (
@@ -27,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
from sglang.srt.speculative.spec_info import SpecInput
try:
from aiter import (
@@ -374,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
):
if forward_mode.is_decode_or_idle():
qo_indptr = None
@@ -509,7 +504,7 @@ class AiterAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
@@ -888,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
seq_lens_sum: int,
prefix_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
@@ -900,7 +895,7 @@ class AiterIndicesUpdaterPrefill:
seq_lens_sum: int,
prefix_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
):
kv_start_idx = None
@@ -984,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
extend_lens: torch.Tensor,
max_q_len: int,
max_kv_len: int,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
@@ -997,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill:
extend_lens: torch.Tensor,
max_q_len: int,
max_kv_len: int,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
):
bs = len(req_pool_indices)
@@ -1054,7 +1049,7 @@ class AiterMultiStepDraftBackend:
topk: int,
speculative_num_steps: int,
):
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
self.topk = topk
self.speculative_num_steps = speculative_num_steps

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, List, Optional
import torch
import torch_npu
from torch.nn.functional import scaled_dot_product_attention
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -13,7 +12,8 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import get_bool_env_var
if TYPE_CHECKING:
@@ -127,7 +127,7 @@ class AscendAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
):
metadata = ForwardMetadata()
@@ -147,7 +147,7 @@ class AscendAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
metadata = self.graph_metadata[bs]

View File

@@ -8,7 +8,7 @@ import torch
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
class AttentionBackend(ABC):
@@ -31,7 +31,7 @@ class AttentionBackend(ABC):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
):
"""Init the metadata for a forward pass for capturing a cuda graph."""
raise NotImplementedError()
@@ -44,7 +44,7 @@ class AttentionBackend(ABC):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
"""Init the metadata for a forward pass for replaying a cuda graph."""

View File

@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
from sglang.srt.speculative.spec_info import SpecInput
_is_cuda = is_cuda()
if _is_cuda:
@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
):
if forward_mode.is_decode_or_idle():
if spec_info is None:
@@ -190,7 +190,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):

View File

@@ -11,9 +11,8 @@ import triton.language as tl
from sglang.srt.configs.model_config import AttentionArch
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import SWAKVPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
@@ -1487,7 +1486,7 @@ class FlashAttentionBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
):
"""Initialize forward metadata for capturing CUDA graph."""
metadata = FlashAttentionMetadata()
@@ -1722,7 +1721,7 @@ class FlashAttentionBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
out_cache_loc: Optional[torch.Tensor] = None,
):
@@ -2340,7 +2339,7 @@ class FlashAttentionMultiStepBackend:
forward_batch: ForwardBatch,
):
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info.is_draft_input()
for i in range(self.speculative_num_steps - 1):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -2357,7 +2356,7 @@ class FlashAttentionMultiStepBackend:
self, forward_batch: ForwardBatch, bs: int
):
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info.is_draft_input()
for i in range(self.speculative_num_steps - 1):
# TODO: incrementally update the metadata for the later steps,

View File

@@ -28,8 +28,8 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.layers.radix_attention import AttentionType
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
get_int_env_var,
is_flashinfer_available,
@@ -344,7 +344,7 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[SpecInput],
):
if forward_mode.is_decode_or_idle():
decode_wrappers = []
@@ -451,7 +451,7 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
@@ -669,7 +669,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -684,7 +684,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -710,7 +710,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -760,7 +760,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
disable_split_kv: Optional[bool] = None,
):
@@ -794,7 +794,7 @@ class FlashInferIndicesUpdaterDecode:
paged_kernel_lens_sum: int,
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
@@ -905,7 +905,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
):
# Keep the signature for type checking. It will be assigned during runtime.
@@ -921,7 +921,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
):
if use_ragged:
@@ -959,7 +959,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
@@ -1006,7 +1006,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[SpecInput],
fixed_split_size: Optional[int] = None,
):
for wrapper_id in range(2):
@@ -1049,7 +1049,7 @@ class FlashInferIndicesUpdaterPrefill:
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
spec_info: Optional[SpecInput],
use_sliding_window_kv_pool: bool = False,
fixed_split_size: Optional[int] = None,
):
@@ -1077,9 +1077,7 @@ class FlashInferIndicesUpdaterPrefill:
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
else:
assert isinstance(
spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
)
assert isinstance(spec_info, SpecInput)
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
@@ -1138,7 +1136,7 @@ class FlashInferMultiStepDraftBackend:
topk: int,
speculative_num_steps: int,
):
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
self.topk = topk
self.speculative_num_steps = speculative_num_steps
@@ -1202,7 +1200,7 @@ class FlashInferMultiStepDraftBackend:
)
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info.is_draft_input()
# Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()

View File

@@ -30,7 +30,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import (
is_flashinfer_available,
is_sm100_supported,
@@ -40,7 +40,7 @@ from sglang.srt.utils import (
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
from sglang.srt.speculative.spec_info import SpecInput
if is_flashinfer_available():
from flashinfer import (
@@ -361,7 +361,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
):
if forward_mode.is_decode_or_idle():
decode_wrapper = BatchMLAPagedAttentionWrapper(
@@ -441,7 +441,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
if forward_mode.is_decode_or_idle():
@@ -663,7 +663,7 @@ class FlashInferMLAIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrapper: BatchMLAPagedAttentionWrapper,
init_metadata_replay: bool = False,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
spec_info: Optional[SpecInput] = None,
**fast_decode_kwargs,
):
decode_wrapper = decode_wrapper or self.decode_wrapper
@@ -688,7 +688,7 @@ class FlashInferMLAIndicesUpdaterDecode:
q_indptr: torch.Tensor,
kv_indptr: torch.Tensor,
init_metadata_replay: bool = False,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
spec_info: Optional[SpecInput] = None,
**fast_decode_kwargs,
):
bs = len(req_pool_indices)
@@ -776,7 +776,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
prefix_lens: torch.Tensor,
prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
use_ragged: bool,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
spec_info: Optional[SpecInput] = None,
):
if use_ragged:
paged_kernel_lens = prefix_lens
@@ -811,7 +811,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
spec_info: Optional[SpecInput] = None,
):
bs = len(seq_lens)
sm_scale = self.scaling
@@ -838,9 +838,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
else:
assert isinstance(spec_info, EagleDraftInput) or isinstance(
spec_info, EagleVerifyInput
)
assert isinstance(spec_info, SpecInput)
# TODO: Support topk > 1 with custom mask
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
@@ -894,7 +892,7 @@ class FlashInferMLAMultiStepDraftBackend:
topk: int,
speculative_num_steps: int,
):
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
if topk > 1:
raise ValueError(
@@ -963,7 +961,7 @@ class FlashInferMLAMultiStepDraftBackend:
)
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info.is_draft_input()
for i in range(self.speculative_num_steps - 1):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
@@ -983,8 +981,6 @@ class FlashInferMLAMultiStepDraftBackend:
)
def call_fn(i, forward_batch):
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone()
)

View File

@@ -19,7 +19,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
from sglang.srt.speculative.spec_info import SpecInput
# FlashMLA only supports pagesize=64
@@ -187,7 +187,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
):
if forward_mode.is_decode_or_idle():
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
@@ -257,7 +257,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):

View File

@@ -6,7 +6,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
class HybridAttnBackend(AttentionBackend):
@@ -71,7 +71,7 @@ class HybridAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
):
backend = self._select_backend(forward_mode)
backend.init_forward_metadata_capture_cuda_graph(
@@ -92,7 +92,7 @@ class HybridAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
backend = self._select_backend(forward_mode)

View File

@@ -21,8 +21,8 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.models.qwen3_next import fused_gdn_gating
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import is_cuda, is_npu
if is_cuda():
@@ -134,7 +134,7 @@ class MambaAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
):
if forward_mode.is_decode_or_idle():
self.query_start_loc_list[bs - 1].copy_(
@@ -161,7 +161,7 @@ class MambaAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
num_padding = torch.count_nonzero(
@@ -451,7 +451,7 @@ class HybridLinearAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
):
for attn_backend in self.attn_backend_list:
attn_backend.init_forward_metadata_capture_cuda_graph(
@@ -472,7 +472,7 @@ class HybridLinearAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
for attn_backend in self.attn_backend_list:

View File

@@ -1,10 +1,10 @@
from typing import TYPE_CHECKING, Callable, List, Optional, Union
from typing import TYPE_CHECKING, Callable, List, Optional
import torch
from sglang.srt import two_batch_overlap
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
if TYPE_CHECKING:
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -46,7 +46,7 @@ class TboAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
):
self.primary.init_forward_metadata_capture_cuda_graph(
bs=bs,
@@ -77,7 +77,7 @@ class TboAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
self.primary.init_forward_metadata_replay_cuda_graph(
@@ -112,7 +112,7 @@ class TboAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
# capture args
capture_num_tokens: int = None,
# replay args
@@ -196,7 +196,7 @@ def _init_forward_metadata_cuda_graph_split(
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: "ForwardMode",
spec_info: Optional[EagleVerifyInput],
spec_info: Optional[SpecInput],
# capture args
capture_num_tokens: int = None,
# replay args

View File

@@ -22,7 +22,7 @@ from sglang.srt.utils import (
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
def logit_capping_mod(logit_capping_method, logit_cap):
@@ -482,7 +482,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
):
assert encoder_lens is None, "Not supported"
window_kv_indptr = self.window_kv_indptr
@@ -638,7 +638,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
# NOTE: encoder_lens expected to be zeros or None
@@ -883,7 +883,7 @@ class TritonMultiStepDraftBackend:
topk: int,
speculative_num_steps: int,
):
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
self.topk = topk
self.speculative_num_steps = speculative_num_steps

View File

@@ -20,12 +20,10 @@ from sglang.srt.utils import is_flashinfer_available
if is_flashinfer_available():
import flashinfer
from sglang.srt.speculative.eagle_utils import EagleDraftInput
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
from sglang.srt.speculative.spec_info import SpecInput
# Constants
DEFAULT_WORKSPACE_SIZE_MB = (
@@ -201,7 +199,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
):
"""Initialize metadata for CUDA graph capture."""
metadata = TRTLLMMHAMetadata()
@@ -314,7 +312,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
"""Replay CUDA graph with new inputs."""
@@ -661,7 +659,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
forward_batch: ForwardBatch,
):
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info.is_draft_input()
for i in range(self.speculative_num_steps - 1):
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
@@ -678,7 +676,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
self, forward_batch: ForwardBatch, bs: int
):
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info.is_draft_input()
for i in range(self.speculative_num_steps - 1):

View File

@@ -30,7 +30,7 @@ if is_flashinfer_available():
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo
from sglang.srt.speculative.spec_info import SpecInput
_is_cuda = is_cuda()
@@ -214,7 +214,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
):
"""Initialize metadata for CUDA graph capture."""
@@ -270,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
"""Replay CUDA graph with new inputs."""

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import logging
from dataclasses import dataclass
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Optional
import torch
import triton
@@ -17,7 +17,7 @@ from sglang.srt.utils import get_bool_env_var, get_device_core_count
if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
logger = logging.getLogger(__name__)
@@ -393,7 +393,7 @@ class WaveAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
):
assert encoder_lens is None, "Not supported"
@@ -477,7 +477,7 @@ class WaveAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
seq_lens_cpu: Optional[torch.Tensor],
):
# NOTE: encoder_lens expected to be zeros or None

View File

@@ -11,12 +11,8 @@ from sglang.srt.distributed import (
get_moe_expert_parallel_world_size,
get_moe_tensor_parallel_rank,
get_moe_tensor_parallel_world_size,
get_tp_group,
tensor_model_parallel_all_reduce,
)
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
use_symmetric_memory,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.layers.moe import (
MoeRunnerConfig,
@@ -24,7 +20,6 @@ from sglang.srt.layers.moe import (
should_use_flashinfer_trtllm_moe,
)
from sglang.srt.layers.moe.token_dispatcher.standard import (
CombineInput,
StandardDispatcher,
StandardDispatchOutput,
)

View File

@@ -73,9 +73,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
if TYPE_CHECKING:
from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
@@ -957,9 +955,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
None
)
# spec_info: Optional[SpecInput] = None
spec_info: Optional[SpecInput] = None
# Whether to return hidden states
return_hidden_states: bool = False
@@ -1995,9 +1992,9 @@ class ModelWorkerBatch:
# Speculative decoding
spec_algorithm: SpeculativeAlgorithm = None
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
None
)
spec_info: Optional[SpecInput] = None
# If set, the output of the batch contains the hidden states of the run.
capture_hidden_mode: CaptureHiddenMode = None
hicache_consumer_index: int = -1

View File

@@ -607,7 +607,7 @@ class CPUGraphRunner:
def get_spec_info(self, num_tokens: int):
spec_info = None
if self.model_runner.spec_algorithm.is_eagle():
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
from sglang.srt.speculative.eagle_info import EagleVerifyInput
if self.model_runner.is_draft_worker:
raise RuntimeError("This should not happen.")

View File

@@ -821,7 +821,7 @@ class CudaGraphRunner:
self.model_runner.spec_algorithm.is_eagle()
or self.model_runner.spec_algorithm.is_standalone()
):
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
from sglang.srt.speculative.eagle_info import EagleVerifyInput
if self.model_runner.is_draft_worker:
raise RuntimeError("This should not happen.")

View File

@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
get_attention_tp_size,
set_dp_buffer_len,
)
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
from sglang.srt.utils import (
flatten_nested_list,
get_compiler_backend,
is_npu,
support_triton,
)
from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
if TYPE_CHECKING:
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
@@ -60,8 +54,7 @@ if TYPE_CHECKING:
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
_is_npu = is_npu()
@@ -293,7 +286,7 @@ class ForwardBatch:
global_forward_mode: Optional[ForwardMode] = None
# Speculative decoding
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
spec_info: Optional[SpecInput] = None
spec_algorithm: SpeculativeAlgorithm = None
capture_hidden_mode: CaptureHiddenMode = None
@@ -364,33 +357,14 @@ class ForwardBatch:
# For MLP sync
if batch.global_num_tokens is not None:
from sglang.srt.speculative.eagle_utils import (
EagleDraftInput,
EagleVerifyInput,
)
assert batch.global_num_tokens_for_logprob is not None
# process global_num_tokens and global_num_tokens_for_logprob
if batch.spec_info is not None:
if isinstance(batch.spec_info, EagleDraftInput):
global_num_tokens = [
x * batch.spec_info.num_tokens_per_batch
for x in batch.global_num_tokens
]
global_num_tokens_for_logprob = [
x * batch.spec_info.num_tokens_for_logprob_per_batch
for x in batch.global_num_tokens_for_logprob
]
else:
assert isinstance(batch.spec_info, EagleVerifyInput)
global_num_tokens = [
x * batch.spec_info.draft_token_num
for x in batch.global_num_tokens
]
global_num_tokens_for_logprob = [
x * batch.spec_info.draft_token_num
for x in batch.global_num_tokens_for_logprob
]
spec_info: SpecInput = batch.spec_info
global_num_tokens, global_num_tokens_for_logprob = (
spec_info.get_spec_adjusted_global_num_tokens(batch)
)
else:
global_num_tokens = batch.global_num_tokens
global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
@@ -669,9 +643,6 @@ class ForwardBatch:
)
def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
from sglang.srt.speculative.eagle_utils import EagleDraftInput
assert self.global_num_tokens_cpu is not None
assert self.global_num_tokens_for_logprob_cpu is not None
@@ -768,7 +739,8 @@ class ForwardBatch:
if self.extend_seq_lens is not None:
self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
if self.spec_info is not None and self.spec_info.is_draft_input():
# FIXME(lsyin): remove this isinstance logic
spec_info = self.spec_info
self.output_cache_loc_backup = self.out_cache_loc
self.hidden_states_backup = spec_info.hidden_states

View File

@@ -20,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.utils import (
require_attn_tp_gather,
require_gathered_buffer,

View File

@@ -21,7 +21,8 @@ from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
)
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
from sglang.srt.speculative.eagle_info import EagleDraftInput
from sglang.srt.speculative.spec_utils import fast_topk
from sglang.srt.utils import (
require_attn_tp_gather,
require_gathered_buffer,

View File

@@ -1,235 +1,52 @@
from __future__ import annotations
import copy
import logging
import os
import time
from copy import copy
from dataclasses import dataclass
from typing import List, Optional
from typing import List, Optional, Tuple
import torch
import torch.nn.functional as F
import triton
import triton.language as tl
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.environ import envs
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.schedule_batch import (
Req,
ScheduleBatch,
get_last_loc,
global_server_args_dict,
)
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
from sglang.srt.speculative.spec_utils import (
SIMULATE_ACC_LEN,
TREE_SPEC_KERNEL_AVAILABLE,
_generate_simulated_accept_index,
align_evict_mask_to_page_size,
assign_req_to_token_pool,
create_accept_length_filter,
create_extend_after_decode_spec_info,
filter_finished_cache_loc_kernel,
get_src_tgt_cache_loc,
get_target_cache_loc,
)
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
if is_cuda():
from sgl_kernel import (
fast_topk,
top_k_renorm_prob,
top_p_renorm_prob,
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
elif is_hip():
from sgl_kernel import fast_topk, verify_tree_greedy
from sgl_kernel import verify_tree_greedy
logger = logging.getLogger(__name__)
# Simulate acceptance length for benchmarking purposes
SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0
SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
@dataclass
class EagleDraftInput:
    """State carried by the EAGLE draft model across speculative-decoding
    phases (decode, extend, draft-extend).

    All tensor fields default to ``None`` and are populated lazily by the
    phase that produces them; batch size is denoted ``b`` below.
    """

    # The inputs for decode
    # shape: (b, topk)
    topk_p: torch.Tensor = None
    topk_index: torch.Tensor = None
    # shape: (b, hidden_size)
    hidden_states: torch.Tensor = None
    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL

    # Inputs for extend
    # shape: (b,)
    verified_id: torch.Tensor = None
    accept_length: torch.Tensor = None
    accept_length_cpu: List[int] = None

    # Inputs for the attention backends
    # shape: (b + 1,)
    kv_indptr: torch.Tensor = None
    kv_indices: torch.Tensor = None

    # Shape info for padding
    num_tokens_per_batch: int = -1
    num_tokens_for_logprob_per_batch: int = -1

    # Inputs for draft extend
    # shape: (b,)
    seq_lens_for_draft_extend: torch.Tensor = None
    req_pool_indices_for_draft_extend: torch.Tensor = None

    def prepare_for_extend(self, batch: ScheduleBatch):
        """Rewrite ``batch.input_ids`` in place so each request's ids are
        shifted left by one and its verified token is appended at the end.

        No-op for idle batches.
        """
        if batch.forward_mode.is_idle():
            return

        # Prefill only generate 1 token.
        assert len(self.verified_id) == len(batch.seq_lens)

        pt = 0
        for i, extend_len in enumerate(batch.extend_lens):
            input_ids = batch.input_ids[pt : pt + extend_len]
            # Drop the first token of the segment and append the newly
            # verified token, keeping the segment length unchanged.
            batch.input_ids[pt : pt + extend_len] = torch.cat(
                (input_ids[1:], self.verified_id[i].reshape(1))
            )
            pt += extend_len

    @classmethod
    def create_idle_input(
        cls,
        device: torch.device,
        hidden_size: int,
        dtype: torch.dtype,
        topk: int,
        capture_hidden_mode: CaptureHiddenMode,
    ):
        """Build an empty (batch-size 0) draft input for idle forward passes."""
        return cls(
            verified_id=torch.empty((0,), device=device, dtype=torch.int32),
            hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
            topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
            topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
            capture_hidden_mode=capture_hidden_mode,
            accept_length=torch.empty((0,), device=device, dtype=torch.int32),
            accept_length_cpu=[],
        )

    def prepare_extend_after_decode(
        self,
        batch: ScheduleBatch,
        speculative_num_steps: int,
    ):
        """Convert a just-verified decode batch into a draft-extend batch.

        Mutates ``batch`` (input_ids, extend lens, seq_lens, pool indices)
        and fills ``self.positions`` / ``self.verified_id`` via a Triton
        kernel. No-op for idle batches.
        """
        if batch.forward_mode.is_idle():
            return

        batch.input_ids = self.verified_id
        batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
        batch.extend_num_tokens = sum(batch.extend_lens)
        batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
        batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
        batch.return_logprob = False
        batch.return_hidden_states = False

        self.capture_hidden_mode = CaptureHiddenMode.LAST
        # +1 — presumably to count the bonus token alongside the accepted
        # draft tokens; TODO(review) confirm against the verify pass.
        self.accept_length.add_(1)
        self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
        self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)

        create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
            batch.input_ids,
            batch.seq_lens,
            self.accept_length,
            self.positions,
            self.verified_id,
            next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
        )

    def generate_attn_arg_prefill(
        self,
        req_pool_indices: torch.Tensor,
        paged_kernel_lens: torch.Tensor,
        paged_kernel_lens_sum: int,
        req_to_token: torch.Tensor,
    ):
        """Build prefill attention metadata from the accepted lengths.

        Returns ``(kv_indices, cum_kv_seq_len, qo_indptr, None)``; the last
        element is the (absent) custom mask.
        """
        bs = self.accept_length.numel()
        qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
        qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
        cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
        cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

        if paged_kernel_lens_sum is None:
            paged_kernel_lens_sum = cum_kv_seq_len[-1]

        kv_indices = torch.empty(
            paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
        )

        create_flashinfer_kv_indices_triton[(bs,)](
            req_to_token,
            req_pool_indices,
            paged_kernel_lens,
            cum_kv_seq_len,
            None,
            kv_indices,
            req_to_token.size(1),
        )
        return kv_indices, cum_kv_seq_len, qo_indptr, None

    def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
        """Shrink the per-request fields to the requests kept by the scheduler.

        With ``has_been_filtered`` the tensors were already reordered by the
        verify step, so a plain truncation suffices; otherwise index with
        ``new_indices``.
        """
        if has_been_filtered:
            # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
            # therefore, we don't need to filter the batch again in scheduler
            if len(new_indices) != len(self.topk_p):
                logger.warning(
                    f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
                )
            self.topk_p = self.topk_p[: len(new_indices)]
            self.topk_index = self.topk_index[: len(new_indices)]
            self.hidden_states = self.hidden_states[: len(new_indices)]
            self.verified_id = self.verified_id[: len(new_indices)]
        else:
            # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
            self.topk_p = self.topk_p[new_indices]
            self.topk_index = self.topk_index[new_indices]
            self.hidden_states = self.hidden_states[new_indices]
            self.verified_id = self.verified_id[new_indices]

    def merge_batch(self, spec_info: EagleDraftInput):
        """Concatenate another draft input's tensors onto this one (batch dim).

        If this instance is empty, adopt the other's tensors; if the other is
        empty, leave this instance unchanged.
        """
        if self.hidden_states is None:
            self.hidden_states = spec_info.hidden_states
            self.verified_id = spec_info.verified_id
            self.topk_p = spec_info.topk_p
            self.topk_index = spec_info.topk_index
            return
        if spec_info.hidden_states is None:
            return
        self.hidden_states = torch.cat(
            [self.hidden_states, spec_info.hidden_states], axis=0
        )
        self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
        self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
        self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
@dataclass
class EagleVerifyOutput:
    """Result bundle of the EAGLE verification pass.

    Attributes:
        draft_input: Draft input batch for the next speculative round.
        logits_output: Logit outputs from the target worker.
        verified_id: Accepted token ids including the bonus token.
        accept_length_per_req_cpu: Accepted token length per sequence in a
            batch, kept on CPU.
        accepted_indices: Accepted indices from
            ``logits_output.next_token_logits``.
    """

    draft_input: EagleDraftInput
    logits_output: LogitsProcessorOutput
    verified_id: torch.Tensor
    accept_length_per_req_cpu: List[int]
    accepted_indices: torch.Tensor
@dataclass
class EagleVerifyInput:
class EagleVerifyInput(SpecInput):
draft_token: torch.Tensor
custom_mask: torch.Tensor
positions: torch.Tensor
@@ -245,6 +62,12 @@ class EagleVerifyInput:
seq_lens_cpu: torch.Tensor
grammar: BaseGrammarObject = None
def __post_init__(self):
super().__init__(SpecInputType.EAGLE_VERIFY)
def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
return self.draft_token_num, self.draft_token_num
@classmethod
def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
return cls(
@@ -724,574 +547,184 @@ class EagleVerifyInput:
)
@triton.jit
def create_extend_after_decode_spec_info(
verified_id,
seq_lens,
accept_lens,
positions,
new_verified_id,
bs_upper: tl.constexpr,
):
pid = tl.program_id(axis=0)
offsets = tl.arange(0, bs_upper)
seq_length = tl.load(seq_lens + pid)
accept_length = tl.load(accept_lens + pid)
@dataclass
class EagleDraftInput(SpecInput):
# The inputs for decode
# shape: (b, topk)
topk_p: torch.Tensor = None
topk_index: torch.Tensor = None
# shape: (b, hidden_size)
hidden_states: torch.Tensor = None
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
accept_len_cumsum = tl.sum(
tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
)
positions_ptr = positions + accept_len_cumsum
mask = offsets < accept_length
tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
# Inputs for extend
# shape: (b,)
verified_id: torch.Tensor = None
accept_length: torch.Tensor = None
accept_length_cpu: List[int] = None
accept_len_cumsum += accept_length - 1
verified_id_data = tl.load(verified_id + accept_len_cumsum)
tl.store(new_verified_id + pid, verified_id_data)
# Inputs for the attention backends
# shape: (b + 1,)
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None
# Shape info for padding
num_tokens_per_batch: int = -1
num_tokens_for_logprob_per_batch: int = -1
@triton.jit
def assign_req_to_token_pool(
req_pool_indices,
req_to_token,
start_offset,
end_offset,
out_cache_loc,
pool_len: tl.constexpr,
bs_upper: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 32
pid = tl.program_id(axis=0)
kv_start = tl.load(start_offset + pid)
kv_end = tl.load(end_offset + pid)
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
# Inputs for draft extend
# shape: (b,)
seq_lens_for_draft_extend: torch.Tensor = None
req_pool_indices_for_draft_extend: torch.Tensor = None
length_offset = tl.arange(0, bs_upper)
start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
out_offset = tl.sum(end - start, axis=0)
def __post_init__(self):
super().__init__(SpecInputType.EAGLE_DRAFT)
out_cache_ptr = out_cache_loc + out_offset
def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
return self.num_tokens_per_batch, self.num_tokens_for_logprob_per_batch
save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
load_offset = tl.arange(0, BLOCK_SIZE)
def prepare_for_extend(self, batch: ScheduleBatch):
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for _ in range(num_loop):
mask = save_offset < kv_end
data = tl.load(out_cache_ptr + load_offset, mask=mask)
tl.store(token_pool + save_offset, data, mask=mask)
save_offset += BLOCK_SIZE
load_offset += BLOCK_SIZE
if batch.forward_mode.is_idle():
return
# Prefill only generate 1 token.
assert len(self.verified_id) == len(batch.seq_lens)
@triton.jit
def assign_draft_cache_locs(
req_pool_indices,
req_to_token,
seq_lens,
extend_lens,
num_new_pages_per_topk,
out_cache_loc,
pool_len: tl.constexpr,
topk: tl.constexpr,
speculative_num_steps: tl.constexpr,
page_size: tl.constexpr,
bs_upper: tl.constexpr,
iter_upper: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 128
pid = tl.program_id(axis=0)
pt = 0
for i, extend_len in enumerate(batch.extend_lens):
input_ids = batch.input_ids[pt : pt + extend_len]
batch.input_ids[pt : pt + extend_len] = torch.cat(
(input_ids[1:], self.verified_id[i].reshape(1))
)
pt += extend_len
if page_size == 1 or topk == 1:
copy_len = topk * speculative_num_steps
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
else:
bs_offset = tl.arange(0, bs_upper)
copy_len = tl.load(extend_lens + pid)
cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
out_cache_ptr = out_cache_loc + cum_copy_len
# Part 1: Copy from out_cache_loc to req_to_token
kv_start = tl.load(seq_lens + pid)
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
for i in range(num_loop):
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = copy_offset < copy_len
data = tl.load(out_cache_ptr + copy_offset, mask=mask)
tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
if page_size == 1 or topk == 1:
return
# Part 2: Copy the indices for the last partial page
prefix_len = tl.load(seq_lens + pid)
last_page_len = prefix_len % page_size
offsets = tl.arange(0, page_size)
mask = offsets < last_page_len
num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
prefix_base = token_pool + prefix_len - last_page_len
for topk_id in range(topk):
value = tl.load(prefix_base + offsets, mask=mask)
tl.store(
prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
value,
mask=mask,
)
# Part 3: Remove the padding in out_cache_loc
iter_offest = tl.arange(0, iter_upper)
for topk_id in range(topk):
indices = tl.load(
prefix_base
+ topk_id * num_new_pages_per_topk_ * page_size
+ last_page_len
+ iter_offest,
mask=iter_offest < speculative_num_steps,
)
tl.store(
out_cache_loc
+ pid * topk * speculative_num_steps
+ topk_id * speculative_num_steps
+ iter_offest,
indices,
mask=iter_offest < speculative_num_steps,
)
@triton.jit
def generate_draft_decode_kv_indices(
    req_pool_indices,
    req_to_token,
    paged_kernel_lens,
    kv_indices,
    kv_indptr,
    positions,
    pool_len: tl.constexpr,
    kv_indices_stride: tl.constexpr,
    kv_indptr_stride: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
    num_tokens_upper: tl.constexpr,
    page_size: tl.constexpr,
):
    """Fill per-draft-step KV index arrays for draft decode attention.

    Launched on a 3D grid (speculative step, sequence, topk path). Each
    program copies one sequence's existing KV slot ids out of
    ``req_to_token`` and appends the slots of the draft tokens produced so
    far, then writes one ``kv_indptr`` entry for its (seq, topk) pair.
    """
    BLOCK_SIZE: tl.constexpr = 128
    iters = tl.program_id(axis=0)
    bid = tl.program_id(axis=1)
    topk_id = tl.program_id(axis=2)

    num_steps = tl.num_programs(axis=0)
    num_seqs = tl.num_programs(axis=1)
    topk = tl.num_programs(axis=2)

    # Each draft step gets its own slice of the output arrays.
    kv_indices += kv_indices_stride * iters
    kv_indptr += kv_indptr_stride * iters
    iters += 1  # number of draft tokens visible at this step (1-based)

    load_offset = tl.arange(0, bs_upper)
    # Prefix sum of the lengths of the preceding sequences locates this
    # program's output region.
    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
    seq_len = tl.load(paged_kernel_lens + bid)
    cum_seq_len = tl.sum(seq_lens)

    # Update kv_indices
    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
    kv_ptr = kv_indices + kv_offset
    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len

    # Copy the sequence's existing KV slot ids in BLOCK_SIZE chunks.
    kv_offset = tl.arange(0, BLOCK_SIZE)
    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = kv_offset < seq_len
        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
        tl.store(kv_ptr + kv_offset, data, mask=mask)
        kv_offset += BLOCK_SIZE

    # Append the draft tokens' slot ids. With paging and topk > 1 each topk
    # path owns its own page-aligned region, so recompute the start offset.
    extend_offset = tl.arange(0, iter_upper)
    if page_size == 1 or topk == 1:
        extend_data = tl.load(
            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
            mask=extend_offset < iters,
        )
    else:
        prefix_len = seq_len
        last_page_len = prefix_len % page_size
        num_new_pages_per_topk = (
            last_page_len + num_steps + page_size - 1
        ) // page_size
        prefix_base = seq_len // page_size * page_size
        start = (
            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
        )
        extend_data = tl.load(
            token_pool_ptr + start + extend_offset,
            mask=extend_offset < iters,
        )

    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)

    # Update kv_indptr
    bs_offset = tl.arange(0, num_tokens_upper)

    # Entry 0 must stay 0, so program (0, 0) writes the final entry instead.
    zid = bid * topk + topk_id
    if zid == 0:
        zid = num_seqs * topk
    positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
    base = tl.sum(positions)
    tl.store(kv_indptr + zid, base + zid * iters)
@triton.jit
def align_evict_mask_to_page_size(
    seq_lens,
    evict_mask,
    page_size: tl.constexpr,
    num_draft_tokens: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Clear eviction flags for draft slots that fall inside a kept page.

    One program per sequence. After counting how many of the sequence's
    ``num_draft_tokens`` mask entries are set, flags within the last
    partially-kept page are forced to ``False`` so whole pages are evicted.
    """
    t_range = tl.arange(0, BLOCK_SIZE)

    bid = tl.program_id(axis=0)
    seq_len = tl.load(seq_lens + bid)
    io_mask = t_range < num_draft_tokens
    mask_row = tl.load(
        evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
    )

    num_trues = tl.sum(mask_row)
    num_false = num_draft_tokens - num_trues

    # First draft-token index (relative to seq_len) that lies on a page
    # boundary at or after the kept tokens; everything before it stays.
    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
        tl.store(evict_mask + bid * num_draft_tokens + i, False)
@triton.jit
def get_target_cache_loc(
    tgt_cache_loc,
    to_free_slots,
    accept_length,
    to_free_num_slots,
    out_cache_loc,
    num_verify_tokens: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Split each sequence's verify-time cache slots into kept and freed.

    One program per sequence. The first ``accept_length + 1`` slots of the
    sequence's row in ``out_cache_loc`` are compacted into ``tgt_cache_loc``;
    the trailing ``to_free_num_slots`` slots are compacted into
    ``to_free_slots``.
    """
    bid = tl.program_id(axis=0)
    offset = tl.arange(0, num_verify_tokens_upper)
    bs_offset = tl.arange(0, bs_upper)

    # write the first part to tgt_cache_loc
    # Output start = sum of previous sequences' (accept_length + 1).
    accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
    copy_len = tl.load(accept_length + bid) + 1
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
    )
    tl.store(
        tgt_cache_loc + tgt_cache_loc_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )

    # write the second part to to_free_num_pages
    to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
    to_free_slots_start = tl.sum(to_free_num_slots_all)

    copy_len = to_free_num_slots_cur
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
        mask=offset < copy_len,
    )
    tl.store(
        to_free_slots + to_free_slots_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )
@torch.compile(dynamic=True)
def get_src_tgt_cache_loc(
    seq_lens: torch.Tensor,
    out_cache_loc: torch.Tensor,
    accept_length: torch.Tensor,
    draft_token_num: int,
    page_size: int,
):
    """Compute the cache-slot relocation plan after verification.

    Returns ``(src_cache_loc, tgt_cache_loc, to_free_num_slots)`` where
    ``src_cache_loc`` are the slots of accepted tokens, ``tgt_cache_loc`` is
    an uninitialized destination buffer of the same shape (filled later by a
    kernel), and ``to_free_num_slots`` counts per-sequence slots to release.
    """
    src_cache_loc = out_cache_loc[accept_index]
    tgt_cache_loc = torch.empty_like(src_cache_loc)
    extended_len = seq_lens + draft_token_num
    # Keep up to the next page boundary past the accepted tokens (+1 for the
    # bonus token), but never more than was actually allocated.
    keep_len = torch.minimum(
        (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
        extended_len,
    )
    to_free_num_slots = extended_len - keep_len
    return src_cache_loc, tgt_cache_loc, to_free_num_slots
@triton.jit
def filter_finished_cache_loc_kernel(
    out_cache_loc,
    tgt_cache_loc,
    accept_length,
    accept_length_filter,
    bs_upper: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
):
    """Compact the cache slots of unfinished sequences into ``out_cache_loc``.

    One program per sequence. ``accept_length_filter`` is zero for finished
    sequences, so their slots are skipped; surviving slots are copied from
    ``tgt_cache_loc`` into a densely packed prefix of ``out_cache_loc``.
    """
    bid = tl.program_id(0)
    bs_offset = tl.arange(0, bs_upper)

    # Read start = sum of previous sequences' (accept_length + 1).
    accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    old_start = tl.sum(accept_length_all) + bid

    # Write start = sum of previous sequences' filtered lengths.
    accept_length_filter_all = tl.load(
        accept_length_filter + bs_offset, mask=bs_offset < bid
    )
    new_start = tl.sum(accept_length_filter_all)

    copy_len = tl.load(accept_length_filter + bid)
    copy_offset = tl.arange(0, num_verify_tokens_upper)
    value = tl.load(
        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
    )
    tl.store(
        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
    )
@torch.compile(dynamic=True)
def create_accept_length_filter(
    accept_length: torch.Tensor,
    unfinished_index_device: torch.Tensor,
    seq_lens: torch.Tensor,
):
    """Build a per-sequence copy-length filter and advance ``seq_lens``.

    Unfinished sequences get ``accept_length + 1`` (accepted plus bonus
    token); finished ones get 0 so their cache slots are dropped.
    NOTE: mutates ``seq_lens`` in place for *all* sequences.
    """
    accept_length_filter = torch.zeros_like(accept_length)
    accept_length_filter[unfinished_index_device] = (
        accept_length[unfinished_index_device] + 1
    )
    seq_lens.add_(accept_length + 1)
    return accept_length_filter
@torch.compile(dynamic=True)
def select_top_k_tokens(
    i: int,
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: torch.Tensor,
    scores: torch.Tensor,
    topk: int,
):
    """Select the next level of the speculative draft tree.

    Returns ``(input_ids, hidden_states, scores, tree_info)`` where
    ``tree_info`` packs (cumulative scores, candidate token ids, parent
    positions) for tree construction. Step ``i == 0`` fans out the extend
    output into ``topk`` paths; later steps keep the global top-k of the
    ``topk * topk`` expanded candidates.
    """
    if i == 0:
        # The first step after extend
        input_ids = topk_index.flatten()
        hidden_states = hidden_states.repeat_interleave(topk, dim=0)
        scores = topk_p  # shape: (b, topk)
        tree_info = (
            topk_p.unsqueeze(1),  # shape: (b, 1, topk)
            topk_index,  # shape: (b, topk)
            torch.arange(-1, topk, dtype=torch.long, device="cuda")
            .unsqueeze(0)
            .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
        )
    else:
        # The later decode steps
        # Cumulative path score = parent score * child probability.
        expand_scores = torch.mul(
            scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
        )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
        topk_cs_p, topk_cs_index = fast_topk(
            expand_scores.flatten(start_dim=1), topk, dim=-1
        )  # (b, topk)
        scores = topk_cs_p  # shape: (b, topk)

        topk_index = topk_index.reshape(-1, topk**2)
        input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()

        if hidden_states.shape[0] > 0:
            # Map each selected child back to its parent's hidden state.
            selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
                0, hidden_states.shape[0], step=topk, device="cuda"
            ).repeat_interleave(topk)
            hidden_states = hidden_states[selected_input_index, :]

        tree_info = (
            expand_scores,  # shape: (b, topk, topk)
            topk_index,  # shape: (b, topk * topk)
            topk_cs_index + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
        )

    return input_ids, hidden_states, scores, tree_info
def _generate_simulated_accept_index(
    accept_index,
    predict,
    accept_length,
    bs,
    spec_steps,
    simulate_acc_len: float = SIMULATE_ACC_LEN,
    simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
    """Overwrite verification results with a simulated acceptance length.

    Used for benchmarking: every sequence in the batch is forced to accept
    the same number of tokens. Mutates ``accept_length`` (filled with the
    simulated length - 1) and ``predict`` (filled with a dummy token id),
    and returns the simulated ``accept_index`` tensor of shape
    ``(bs, spec_steps + 1)`` with -1 padding.

    Raises:
        ValueError: if ``simulate_acc_method`` is not one of
            ``"multinomial"`` or ``"match-expected"``.
    """
    assert simulate_acc_len > 0.0

    if simulate_acc_method == "multinomial":
        simulated_values = torch.normal(
            mean=simulate_acc_len,
            std=1.0,
            size=(1,),
            device="cpu",
        )
        # clamp simulated values to be between 1 and self.spec_steps
        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
        simulate_acc_len = int(simulated_values.round().item())
    elif simulate_acc_method == "match-expected":
        # multinomial sampling does not match the expected length
        # we keep it for the sake of compatibility of existing tests
        # but it's better to use "match-expected" for the cases that need to
        # match the expected length, One caveat is that this will only sample
        # either round down or round up of the expected length
        simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
        lower = int(simulate_acc_len // 1)
        upper = lower + 1 if lower < spec_steps + 1 else lower
        if lower == upper:
            simulate_acc_len = lower
        else:
            weight_upper = simulate_acc_len - lower
            weight_lower = 1.0 - weight_upper
            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
            sampled_index = torch.multinomial(probs, num_samples=1)
            simulate_acc_len = lower if sampled_index == 0 else upper
    else:
        # Bug fix: report the method actually in effect (the parameter), not
        # the module-level default SIMULATE_ACC_METHOD.
        raise ValueError(f"Invalid simulate_acc_method: {simulate_acc_method}")

    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
    # Generalized: allocate on accept_index's device instead of hard-coding
    # "cuda" (the arange below already follows accept_index.device).
    sim_accept_index = torch.full(
        (bs, spec_steps + 1), -1, dtype=torch.int32, device=accept_index.device
    )
    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
        simulate_acc_len, device=accept_index.device
    )
    accept_length.fill_(simulate_acc_len - 1)
    predict.fill_(100)  # some legit token id
    return sim_accept_index
def traverse_tree(
retrieve_next_token: torch.Tensor,
retrieve_next_sibling: torch.Tensor,
draft_tokens: torch.Tensor,
grammar: BaseGrammarObject,
allocate_token_bitmask: torch.Tensor,
):
"""
Traverse the tree constructed by the draft model to generate the logits mask.
"""
assert (
retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
)
allocate_token_bitmask.fill_(0)
def dfs(
curr: int,
retrieve_next_token: torch.Tensor,
retrieve_next_sibling: torch.Tensor,
parent_pos: int,
@classmethod
def create_idle_input(
cls,
device: torch.device,
hidden_size: int,
dtype: torch.dtype,
topk: int,
capture_hidden_mode: CaptureHiddenMode,
):
if curr == 0:
# the first token generated by the target model, and thus it is always
# accepted from the previous iteration
accepted = True
else:
parent_bitmask = allocate_token_bitmask[parent_pos]
curr_token_id = draft_tokens[curr]
# 32 boolean bitmask values are packed into 32-bit integers
accepted = (
parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
) != 0
return cls(
verified_id=torch.empty((0,), device=device, dtype=torch.int32),
hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
capture_hidden_mode=capture_hidden_mode,
accept_length=torch.empty((0,), device=device, dtype=torch.int32),
accept_length_cpu=[],
)
if accepted:
if curr != 0:
# Accept the current token
grammar.accept_token(draft_tokens[curr])
if not grammar.is_terminated():
# Generate the bitmask for the current token
grammar.fill_vocab_mask(allocate_token_bitmask, curr)
if retrieve_next_token[curr] != -1:
# Visit the child node
dfs(
retrieve_next_token[curr],
retrieve_next_token,
retrieve_next_sibling,
curr,
)
def prepare_extend_after_decode(
self,
batch: ScheduleBatch,
speculative_num_steps: int,
):
if curr != 0:
# Rollback the current token
grammar.rollback(1)
if batch.forward_mode.is_idle():
return
if retrieve_next_sibling[curr] != -1:
# Visit the sibling node
dfs(
retrieve_next_sibling[curr],
retrieve_next_token,
retrieve_next_sibling,
parent_pos,
)
batch.input_ids = self.verified_id
batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
batch.extend_num_tokens = sum(batch.extend_lens)
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
batch.return_logprob = False
batch.return_hidden_states = False
dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
self.capture_hidden_mode = CaptureHiddenMode.LAST
self.accept_length.add_(1)
self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
batch.input_ids,
batch.seq_lens,
self.accept_length,
self.positions,
self.verified_id,
next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
)
def generate_token_bitmask(
reqs: List[Req],
verify_input: EagleVerifyInput,
retrieve_next_token_cpu: torch.Tensor,
retrieve_next_sibling_cpu: torch.Tensor,
draft_tokens_cpu: torch.Tensor,
vocab_size: int,
):
"""
Generate the logit mask for structured output.
Draft model's token can be either valid or invalid with respect to the grammar.
We need to perform DFS to
1. figure out which tokens are accepted by the grammar.
2. if so, what is the corresponding logit mask.
"""
def generate_attn_arg_prefill(
self,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
req_to_token: torch.Tensor,
):
bs = self.accept_length.numel()
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
num_draft_tokens = draft_tokens_cpu.shape[-1]
if paged_kernel_lens_sum is None:
paged_kernel_lens_sum = cum_kv_seq_len[-1]
allocate_token_bitmask = None
assert len(reqs) == retrieve_next_token_cpu.shape[0]
grammar = None
for i, req in enumerate(reqs):
if req.grammar is not None:
if allocate_token_bitmask is None:
allocate_token_bitmask = req.grammar.allocate_vocab_mask(
vocab_size=vocab_size,
batch_size=draft_tokens_cpu.numel(),
device="cpu",
)
grammar = req.grammar
s = time.perf_counter()
traverse_tree(
retrieve_next_token_cpu[i],
retrieve_next_sibling_cpu[i],
draft_tokens_cpu[i],
req.grammar,
allocate_token_bitmask[
i * num_draft_tokens : (i + 1) * num_draft_tokens
],
)
tree_traverse_time = time.perf_counter() - s
if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
kv_indices = torch.empty(
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
)
create_flashinfer_kv_indices_triton[(bs,)](
req_to_token,
req_pool_indices,
paged_kernel_lens,
cum_kv_seq_len,
None,
kv_indices,
req_to_token.size(1),
)
return kv_indices, cum_kv_seq_len, qo_indptr, None
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
if has_been_filtered:
# in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
# therefore, we don't need to filter the batch again in scheduler
if len(new_indices) != len(self.topk_p):
logger.warning(
f"Bit mask generation took {tree_traverse_time} seconds with "
f"grammar: {req.grammar}"
f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
)
self.topk_p = self.topk_p[: len(new_indices)]
self.topk_index = self.topk_index[: len(new_indices)]
self.hidden_states = self.hidden_states[: len(new_indices)]
self.verified_id = self.verified_id[: len(new_indices)]
else:
# in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
self.topk_p = self.topk_p[new_indices]
self.topk_index = self.topk_index[new_indices]
self.hidden_states = self.hidden_states[new_indices]
self.verified_id = self.verified_id[new_indices]
verify_input.grammar = grammar
return allocate_token_bitmask
def merge_batch(self, spec_info: "EagleDraftInput"):
    """Merge another draft input's batch state into this one.

    If this object holds no batch yet (``hidden_states is None``) it simply
    adopts the incoming tensors; if the incoming object is empty, nothing
    changes. Otherwise every per-request tensor is concatenated along the
    batch dimension (dim 0).
    """
    if self.hidden_states is None:
        # We hold nothing yet: adopt the incoming batch wholesale.
        self.hidden_states = spec_info.hidden_states
        self.verified_id = spec_info.verified_id
        self.topk_p = spec_info.topk_p
        self.topk_index = spec_info.topk_index
        return
    if spec_info.hidden_states is None:
        # Incoming batch is empty; keep the current state untouched.
        return
    # Both sides are populated: concatenate each field along the batch dim.
    for field in ("hidden_states", "verified_id", "topk_p", "topk_index"):
        merged = torch.cat((getattr(self, field), getattr(spec_info, field)), dim=0)
        setattr(self, field, merged)
@dataclass
class EagleVerifyOutput:
    """Result bundle produced by the target model's verification step in
    EAGLE speculative decoding."""

    # Draft input batch
    draft_input: EagleDraftInput
    # Logit outputs from target worker
    logits_output: LogitsProcessorOutput
    # Accepted token ids including the bonus token
    verified_id: torch.Tensor
    # Accepted token length per sequence in a batch in CPU.
    accept_length_per_req_cpu: List[int]
    # Accepted indices from logits_output.next_token_logits
    accepted_indices: torch.Tensor

View File

@@ -34,16 +34,18 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
EAGLEDraftExtendCudaGraphRunner,
)
from sglang.srt.speculative.eagle_utils import (
from sglang.srt.speculative.eagle_info import (
EagleDraftInput,
EagleVerifyInput,
EagleVerifyOutput,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.speculative.spec_utils import (
assign_draft_cache_locs,
fast_topk,
generate_token_bitmask,
select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import copy
import logging
from typing import Optional
from typing import Optional, Tuple
import torch
import triton
@@ -13,6 +13,7 @@ from dataclasses import dataclass
import torch.nn.functional as F
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import apply_custom_logit_processor
from sglang.srt.managers.schedule_batch import (
@@ -21,10 +22,10 @@ from sglang.srt.managers.schedule_batch import (
global_server_args_dict,
)
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.speculative.eagle_utils import (
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
from sglang.srt.speculative.spec_utils import (
TREE_SPEC_KERNEL_AVAILABLE,
assign_req_to_token_pool,
create_flashinfer_kv_indices_triton,
get_src_tgt_cache_loc,
get_target_cache_loc,
)
@@ -42,7 +43,7 @@ elif is_hip():
@dataclass
class NgramVerifyInput:
class NgramVerifyInput(SpecInput):
def __init__(
self,
draft_token: torch.Tensor,
@@ -53,6 +54,7 @@ class NgramVerifyInput:
retrive_next_sibling: torch.Tensor,
draft_token_num: int,
):
super().__init__(SpecInputType.NGRAM_VERIFY)
self.draft_token = draft_token
self.custom_mask = tree_mask
self.positions = positions
@@ -62,6 +64,9 @@ class NgramVerifyInput:
self.draft_token_num = draft_token_num
self.device = self.custom_mask.device
def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
return self.draft_token_num, self.draft_token_num
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
if batch.forward_mode.is_idle():
return

View File

@@ -1,8 +1,5 @@
import logging
import os
import threading
import time
from typing import TYPE_CHECKING, List, Optional, Union
from typing import List, Optional
import numpy as np
import torch
@@ -15,7 +12,6 @@ from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import broadcast_pyobj
logger = logging.getLogger(__name__)

View File

@@ -1,4 +1,8 @@
from abc import ABC, abstractmethod
from enum import IntEnum, auto
from typing import List, Tuple
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
class SpeculativeAlgorithm(IntEnum):
@@ -35,3 +39,41 @@ class SpeculativeAlgorithm(IntEnum):
if name is not None:
name = name.upper()
return name_map[name]
class SpecInputType(IntEnum):
    """Kind tag carried by every :class:`SpecInput` instance.

    NOTE: introduced to distinguish the SpecInput types of the different
    speculative algorithms when asserting in attention backends. If all
    algorithms ever share the same data structure for draft_input and
    verify_input, consider simplifying this away.
    """

    EAGLE_DRAFT = auto()
    EAGLE_VERIFY = auto()
    NGRAM_VERIFY = auto()
class SpecInput(ABC):
    """Abstract base for per-batch speculative-decoding metadata.

    Concrete subclasses (e.g. Eagle draft/verify inputs, ngram verify input)
    tag themselves with a :class:`SpecInputType` and report how many tokens
    per sequence they contribute via the coefficient hooks below.
    """

    def __init__(self, spec_input_type: SpecInputType):
        # Kind tag used by attention backends to assert the expected input type.
        self.spec_input_type = spec_input_type

    def is_draft_input(self) -> bool:
        # FIXME: remove this function which is only used for assertion
        # or use another variable name like `draft_input` to substitute `spec_info`
        return self.spec_input_type == SpecInputType.EAGLE_DRAFT

    def is_verify_input(self) -> bool:
        # True for any verify-stage input, regardless of algorithm.
        return self.spec_input_type in {
            SpecInputType.EAGLE_VERIFY,
            SpecInputType.NGRAM_VERIFY,
        }

    @abstractmethod
    def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
        """Return ``(c1, c2)``: the per-sequence multipliers applied to the
        global token count and the global token count used for logprobs."""
        pass

    def get_spec_adjusted_global_num_tokens(
        self, forward_batch: ModelWorkerBatch
    ) -> Tuple[List[int], List[int]]:
        """Scale the DP-global token counts by this input's coefficients."""
        c1, c2 = self.get_spec_adjust_token_coefficient()
        global_num_tokens = [x * c1 for x in forward_batch.global_num_tokens]
        global_num_tokens_for_logprob = [
            x * c2 for x in forward_batch.global_num_tokens_for_logprob
        ]
        return global_num_tokens, global_num_tokens_for_logprob

View File

@@ -0,0 +1,607 @@
from __future__ import annotations
import logging
import os
import time
from typing import TYPE_CHECKING, List
import torch
import triton
import triton.language as tl
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.environ import envs
from sglang.srt.managers.schedule_batch import Req
from sglang.srt.utils import is_cuda, is_hip
if is_cuda():
from sgl_kernel import fast_topk
elif is_hip():
from sgl_kernel import fast_topk
if TYPE_CHECKING:
from sglang.srt.speculative.eagle_info import EagleVerifyInput
logger = logging.getLogger(__name__)
# Simulate acceptance length for benchmarking purposes
SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0
SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
@triton.jit
def create_extend_after_decode_spec_info(
    verified_id,
    seq_lens,
    accept_lens,
    positions,
    new_verified_id,
    bs_upper: tl.constexpr,
):
    """Per-request kernel (one program per request) preparing the draft-extend
    pass that follows a decode/verify step.

    Writes:
      - ``positions``: position ids of this request's accepted tokens, packed
        contiguously (offset = cumulative accept length of earlier requests),
        covering ``[seq_len - accept_len, seq_len)``.
      - ``new_verified_id[pid]``: the last accepted token id of the request.
    """
    pid = tl.program_id(axis=0)
    offsets = tl.arange(0, bs_upper)
    seq_length = tl.load(seq_lens + pid)
    accept_length = tl.load(accept_lens + pid)
    # Start of this request's slice in the flattened `positions` buffer:
    # sum of the accept lengths of all earlier requests.
    accept_len_cumsum = tl.sum(
        tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
    )
    positions_ptr = positions + accept_len_cumsum
    mask = offsets < accept_length
    tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
    # Advance to the last accepted token of this request; it becomes the
    # request's new verified id.
    accept_len_cumsum += accept_length - 1
    verified_id_data = tl.load(verified_id + accept_len_cumsum)
    tl.store(new_verified_id + pid, verified_id_data)
@triton.jit
def assign_req_to_token_pool(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """One program per request: copy this request's slice of ``out_cache_loc``
    into its row of the request-to-token table,
    ``req_to_token[req_pool_indices[pid], kv_start:kv_end]``, in chunks of 32.

    The request's read offset into the flat ``out_cache_loc`` buffer is the
    sum of (end - start) over all earlier requests.
    """
    BLOCK_SIZE: tl.constexpr = 32
    pid = tl.program_id(axis=0)
    kv_start = tl.load(start_offset + pid)
    kv_end = tl.load(end_offset + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len

    # Cumulative number of slots consumed by earlier requests = read offset.
    length_offset = tl.arange(0, bs_upper)
    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
    out_offset = tl.sum(end - start, axis=0)

    out_cache_ptr = out_cache_loc + out_offset

    # Chunked copy: [kv_start, kv_end) destination, sequential source.
    save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
    load_offset = tl.arange(0, BLOCK_SIZE)

    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = save_offset < kv_end
        data = tl.load(out_cache_ptr + load_offset, mask=mask)
        tl.store(token_pool + save_offset, data, mask=mask)
        save_offset += BLOCK_SIZE
        load_offset += BLOCK_SIZE
@triton.jit
def assign_draft_cache_locs(
    req_pool_indices,
    req_to_token,
    seq_lens,
    extend_lens,
    num_new_pages_per_topk,
    out_cache_loc,
    pool_len: tl.constexpr,
    topk: tl.constexpr,
    speculative_num_steps: tl.constexpr,
    page_size: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
):
    """One program per request: record the cache locations of the draft tokens
    in the request-to-token table.

    For the simple case (``page_size == 1 or topk == 1``) this is a plain copy
    of ``topk * speculative_num_steps`` slots. For the paged multi-top-k case
    it additionally duplicates the request's last partial page for every
    top-k branch (Part 2) and compacts the padded per-branch slots back into
    ``out_cache_loc`` (Part 3).
    """
    BLOCK_SIZE: tl.constexpr = 128
    pid = tl.program_id(axis=0)

    if page_size == 1 or topk == 1:
        # Fixed-size layout: each request owns topk * num_steps contiguous slots.
        copy_len = topk * speculative_num_steps
        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
    else:
        # Variable-size layout: offset is the cumulative extend length of
        # earlier requests.
        bs_offset = tl.arange(0, bs_upper)
        copy_len = tl.load(extend_lens + pid)
        cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
        out_cache_ptr = out_cache_loc + cum_copy_len

    # Part 1: Copy from out_cache_loc to req_to_token
    kv_start = tl.load(seq_lens + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
    for i in range(num_loop):
        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = copy_offset < copy_len
        data = tl.load(out_cache_ptr + copy_offset, mask=mask)
        tl.store(token_pool + kv_start + copy_offset, data, mask=mask)

    if page_size == 1 or topk == 1:
        return

    # Part 2: Copy the indices for the last partial page
    prefix_len = tl.load(seq_lens + pid)
    last_page_len = prefix_len % page_size
    offsets = tl.arange(0, page_size)
    mask = offsets < last_page_len
    num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
    prefix_base = token_pool + prefix_len - last_page_len

    for topk_id in range(topk):
        value = tl.load(prefix_base + offsets, mask=mask)
        tl.store(
            prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
            value,
            mask=mask,
        )

    # Part 3: Remove the padding in out_cache_loc
    iter_offest = tl.arange(0, iter_upper)
    for topk_id in range(topk):
        indices = tl.load(
            prefix_base
            + topk_id * num_new_pages_per_topk_ * page_size
            + last_page_len
            + iter_offest,
            mask=iter_offest < speculative_num_steps,
        )
        tl.store(
            out_cache_loc
            + pid * topk * speculative_num_steps
            + topk_id * speculative_num_steps
            + iter_offest,
            indices,
            mask=iter_offest < speculative_num_steps,
        )
@triton.jit
def generate_draft_decode_kv_indices(
    req_pool_indices,
    req_to_token,
    paged_kernel_lens,
    kv_indices,
    kv_indptr,
    positions,
    pool_len: tl.constexpr,
    kv_indices_stride: tl.constexpr,
    kv_indptr_stride: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
    num_tokens_upper: tl.constexpr,
    page_size: tl.constexpr,
):
    """Build per-draft-step KV indices/indptr for draft decoding.

    Launched on a 3D grid (draft step, sequence, top-k branch). For each
    (step, seq, branch) it copies the sequence's existing token-pool entries
    plus the branch's newly drafted entries into that step's slice of
    ``kv_indices``, and fills the matching ``kv_indptr`` entry.
    """
    BLOCK_SIZE: tl.constexpr = 128
    iters = tl.program_id(axis=0)
    bid = tl.program_id(axis=1)
    topk_id = tl.program_id(axis=2)

    num_steps = tl.num_programs(axis=0)
    num_seqs = tl.num_programs(axis=1)
    topk = tl.num_programs(axis=2)

    # Each draft step owns its own stride-separated slice of both buffers.
    kv_indices += kv_indices_stride * iters
    kv_indptr += kv_indptr_stride * iters
    iters += 1

    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
    seq_len = tl.load(paged_kernel_lens + bid)
    cum_seq_len = tl.sum(seq_lens)

    # Update kv_indices
    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
    kv_ptr = kv_indices + kv_offset
    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len

    # Chunked copy of the sequence's existing KV entries.
    kv_offset = tl.arange(0, BLOCK_SIZE)
    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = kv_offset < seq_len
        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
        tl.store(kv_ptr + kv_offset, data, mask=mask)
        kv_offset += BLOCK_SIZE

    # Append the entries drafted so far (`iters` of them) for this branch.
    extend_offset = tl.arange(0, iter_upper)
    if page_size == 1 or topk == 1:
        extend_data = tl.load(
            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
            mask=extend_offset < iters,
        )
    else:
        # Paged layout: each branch starts on its own page-aligned region
        # (see assign_draft_cache_locs Part 2/3 for the matching layout).
        prefix_len = seq_len
        last_page_len = prefix_len % page_size
        num_new_pages_per_topk = (
            last_page_len + num_steps + page_size - 1
        ) // page_size
        prefix_base = seq_len // page_size * page_size
        start = (
            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
        )
        extend_data = tl.load(
            token_pool_ptr + start + extend_offset,
            mask=extend_offset < iters,
        )

    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)

    # Update kv_indptr
    bs_offset = tl.arange(0, num_tokens_upper)

    zid = bid * topk + topk_id
    if zid == 0:
        zid = num_seqs * topk
    positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
    base = tl.sum(positions)
    tl.store(kv_indptr + zid, base + zid * iters)
@triton.jit
def align_evict_mask_to_page_size(
    seq_lens,
    evict_mask,
    page_size: tl.constexpr,
    num_draft_tokens: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """One program per request: clear eviction-mask bits that fall inside the
    last partially kept page, so that KV eviction stays page-aligned.

    The number of kept (False) entries determines where the final page
    boundary lands; every mask entry from that boundary up to one full page
    is forced to False (i.e. not evicted).
    """
    t_range = tl.arange(0, BLOCK_SIZE)

    bid = tl.program_id(axis=0)
    seq_len = tl.load(seq_lens + bid)
    io_mask = t_range < num_draft_tokens
    mask_row = tl.load(
        evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
    )

    num_trues = tl.sum(mask_row)
    num_false = num_draft_tokens - num_trues
    # First draft-token index of the page that must be kept whole.
    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
        tl.store(evict_mask + bid * num_draft_tokens + i, False)
@triton.jit
def get_target_cache_loc(
    tgt_cache_loc,
    to_free_slots,
    accept_length,
    to_free_num_slots,
    out_cache_loc,
    num_verify_tokens: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """One program per request: split the request's verify-stage cache slots
    into the slots kept for accepted tokens (written to ``tgt_cache_loc``)
    and the trailing slots to be released (written to ``to_free_slots``).

    Both outputs are packed contiguously across the batch using prefix sums
    of ``accept_length`` (+1 bonus token per request) and ``to_free_num_slots``
    respectively.
    """
    bid = tl.program_id(axis=0)
    offset = tl.arange(0, num_verify_tokens_upper)
    bs_offset = tl.arange(0, bs_upper)

    # write the first part to tgt_cache_loc
    accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    # +bid accounts for each earlier request's bonus token.
    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
    copy_len = tl.load(accept_length + bid) + 1
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
    )
    tl.store(
        tgt_cache_loc + tgt_cache_loc_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )

    # write the second part to to_free_num_pages
    to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
    # The slots to free are the last `to_free_num_slots_cur` of this row.
    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
    to_free_slots_start = tl.sum(to_free_num_slots_all)

    copy_len = to_free_num_slots_cur
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
        mask=offset < copy_len,
    )
    tl.store(
        to_free_slots + to_free_slots_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )
@torch.compile(dynamic=True)
def get_src_tgt_cache_loc(
    seq_lens: torch.Tensor,
    out_cache_loc: torch.Tensor,
    accept_index: torch.Tensor,
    accept_length: torch.Tensor,
    draft_token_num: int,
    page_size: int,
):
    """Compute the cache relocation plan for accepted speculative tokens.

    Returns ``(src_cache_loc, tgt_cache_loc, to_free_num_slots)``:
      - ``src_cache_loc``: current slots of the accepted tokens.
      - ``tgt_cache_loc``: allocated here but left uninitialized; it is
        filled later by the ``get_target_cache_loc`` kernel.
      - ``to_free_num_slots``: per request, how many trailing verify slots
        can be released after keeping a page-aligned prefix.
    """
    src_cache_loc = out_cache_loc[accept_index]
    tgt_cache_loc = torch.empty_like(src_cache_loc)
    extended_len = seq_lens + draft_token_num
    # Keep up to the page boundary that covers the accepted tokens plus the
    # bonus token, but never more than the extended length itself.
    keep_len = torch.minimum(
        (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
        extended_len,
    )
    to_free_num_slots = extended_len - keep_len
    return src_cache_loc, tgt_cache_loc, to_free_num_slots
@triton.jit
def filter_finished_cache_loc_kernel(
    out_cache_loc,
    tgt_cache_loc,
    accept_length,
    accept_length_filter,
    bs_upper: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
):
    """One program per request: compact ``tgt_cache_loc`` into ``out_cache_loc``,
    keeping only the slots of unfinished requests.

    ``accept_length_filter`` is accept_length + 1 for unfinished requests and
    0 for finished ones (see ``create_accept_length_filter``), so finished
    requests contribute nothing to the packed output.
    """
    bid = tl.program_id(0)
    bs_offset = tl.arange(0, bs_upper)

    # Source offset: packed layout of accept_length + 1 bonus per request.
    accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    old_start = tl.sum(accept_length_all) + bid

    # Destination offset: packed layout of the filtered lengths.
    accept_length_filter_all = tl.load(
        accept_length_filter + bs_offset, mask=bs_offset < bid
    )
    new_start = tl.sum(accept_length_filter_all)

    copy_len = tl.load(accept_length_filter + bid)
    copy_offset = tl.arange(0, num_verify_tokens_upper)
    value = tl.load(
        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
    )
    tl.store(
        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
    )
@torch.compile(dynamic=True)
def create_accept_length_filter(
    accept_length: torch.Tensor,
    unfinished_index_device: torch.Tensor,
    seq_lens: torch.Tensor,
):
    """Build the per-request copy lengths for compacting finished requests.

    Returns a tensor equal to ``accept_length + 1`` (bonus token included)
    for unfinished requests and 0 for finished ones. Also advances
    ``seq_lens`` in place by ``accept_length + 1`` for every request.
    """
    accept_length_filter = torch.zeros_like(accept_length)
    accept_length_filter[unfinished_index_device] = (
        accept_length[unfinished_index_device] + 1
    )
    seq_lens.add_(accept_length + 1)
    return accept_length_filter
@torch.compile(dynamic=True)
def select_top_k_tokens(
    i: int,
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: torch.Tensor,
    scores: torch.Tensor,
    topk: int,
):
    """Select the next layer of top-k candidates while expanding the draft tree.

    At step ``i == 0`` (right after extend) every request simply fans out into
    its ``topk`` candidates; at later steps the cumulative branch scores are
    multiplied by the new per-branch top-k probabilities and the best ``topk``
    paths overall are kept.

    Returns ``(input_ids, hidden_states, scores, tree_info)`` where
    ``tree_info`` holds (branch scores, candidate token ids, parent indices)
    used later to reconstruct the tree.
    """
    if i == 0:
        # The first step after extend
        input_ids = topk_index.flatten()
        hidden_states = hidden_states.repeat_interleave(topk, dim=0)
        scores = topk_p  # shape: (b, topk)

        tree_info = (
            topk_p.unsqueeze(1),  # shape: (b, 1, topk)
            topk_index,  # shape: (b, topk)
            torch.arange(-1, topk, dtype=torch.long, device="cuda")
            .unsqueeze(0)
            .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
        )
    else:
        # The later decode steps
        expand_scores = torch.mul(
            scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
        )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)

        topk_cs_p, topk_cs_index = fast_topk(
            expand_scores.flatten(start_dim=1), topk, dim=-1
        )  # (b, topk)
        scores = topk_cs_p  # shape: (b, topk)

        topk_index = topk_index.reshape(-1, topk**2)
        input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()

        if hidden_states.shape[0] > 0:
            # Map each selected path back to the hidden state of its parent branch.
            selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
                0, hidden_states.shape[0], step=topk, device="cuda"
            ).repeat_interleave(topk)
            hidden_states = hidden_states[selected_input_index, :]

        tree_info = (
            expand_scores,  # shape: (b, topk, topk)
            topk_index,  # shape: (b, topk * topk)
            topk_cs_index + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
        )

    return input_ids, hidden_states, scores, tree_info
def _generate_simulated_accept_index(
    accept_index,
    predict,
    accept_length,
    bs,
    spec_steps,
    simulate_acc_len: float = SIMULATE_ACC_LEN,
    simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
    """Replace real verification results with simulated ones (benchmark only).

    Used when ``SGLANG_SIMULATE_ACC_LEN`` is enabled to benchmark speculative
    decoding with a controlled acceptance length. Mutates ``accept_length``
    (filled with the simulated length - 1) and ``predict`` (filled with a
    dummy token id) in place, and returns a new accept_index of shape
    ``(bs, spec_steps + 1)`` whose first ``simulate_acc_len`` columns are
    accepted per row.

    Raises:
        ValueError: if ``simulate_acc_method`` is not a known method.
    """
    assert simulate_acc_len > 0.0

    if simulate_acc_method == "multinomial":
        simulated_values = torch.normal(
            mean=simulate_acc_len,
            std=1.0,
            size=(1,),
            device="cpu",
        )
        # Clamp simulated values to be between 1 and spec_steps + 1.
        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
        simulate_acc_len = int(simulated_values.round().item())
    elif simulate_acc_method == "match-expected":
        # Multinomial sampling does not match the expected length.
        # We keep it for the sake of compatibility of existing tests,
        # but it's better to use "match-expected" for the cases that need to
        # match the expected length. One caveat is that this will only sample
        # either round down or round up of the expected length.
        simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
        lower = int(simulate_acc_len // 1)
        upper = lower + 1 if lower < spec_steps + 1 else lower
        if lower == upper:
            simulate_acc_len = lower
        else:
            # Sample lower/upper so the expectation matches simulate_acc_len.
            weight_upper = simulate_acc_len - lower
            weight_lower = 1.0 - weight_upper
            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
            sampled_index = torch.multinomial(probs, num_samples=1)
            simulate_acc_len = lower if sampled_index == 0 else upper
    else:
        # Bug fix: report the method that was actually passed in. The old code
        # interpolated the module-level default SIMULATE_ACC_METHOD, so the
        # error message could name a valid method instead of the invalid one.
        raise ValueError(f"Invalid simulate_acc_method: {simulate_acc_method}")

    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
    sim_accept_index = torch.full(
        (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda"
    )
    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
        simulate_acc_len, device=accept_index.device
    )
    accept_length.fill_(simulate_acc_len - 1)
    predict.fill_(100)  # some legit token id
    return sim_accept_index
def traverse_tree(
    retrieve_next_token: torch.Tensor,
    retrieve_next_sibling: torch.Tensor,
    draft_tokens: torch.Tensor,
    grammar: BaseGrammarObject,
    allocate_token_bitmask: torch.Tensor,
):
    """
    Traverse the tree constructed by the draft model to generate the logits mask.

    The tree is encoded per node by `retrieve_next_token` (first child) and
    `retrieve_next_sibling` (next sibling); node 0 is the root. The grammar
    object is mutated along the walk (accept_token on descent, rollback on
    ascent) so that `allocate_token_bitmask[curr]` reflects the grammar state
    at node `curr`.
    """
    assert (
        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
    )

    allocate_token_bitmask.fill_(0)

    def dfs(
        curr: int,
        retrieve_next_token: torch.Tensor,
        retrieve_next_sibling: torch.Tensor,
        parent_pos: int,
    ):
        """Depth-first walk from node `curr`; `parent_pos` is the bitmask row
        of the node's parent (-1 for the root)."""
        if curr == 0:
            # the first token generated by the target model, and thus it is always
            # accepted from the previous iteration
            accepted = True
        else:
            parent_bitmask = allocate_token_bitmask[parent_pos]
            curr_token_id = draft_tokens[curr]
            # 32 boolean bitmask values are packed into 32-bit integers
            accepted = (
                parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
            ) != 0

        if accepted:
            if curr != 0:
                # Accept the current token
                grammar.accept_token(draft_tokens[curr])
            if not grammar.is_terminated():
                # Generate the bitmask for the current token
                grammar.fill_vocab_mask(allocate_token_bitmask, curr)
                if retrieve_next_token[curr] != -1:
                    # Visit the child node
                    dfs(
                        retrieve_next_token[curr],
                        retrieve_next_token,
                        retrieve_next_sibling,
                        curr,
                    )

            if curr != 0:
                # Rollback the current token so siblings see the parent's state
                grammar.rollback(1)

        if retrieve_next_sibling[curr] != -1:
            # Visit the sibling node
            dfs(
                retrieve_next_sibling[curr],
                retrieve_next_token,
                retrieve_next_sibling,
                parent_pos,
            )

    dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
def generate_token_bitmask(
    reqs: List[Req],
    verify_input: EagleVerifyInput,
    retrieve_next_token_cpu: torch.Tensor,
    retrieve_next_sibling_cpu: torch.Tensor,
    draft_tokens_cpu: torch.Tensor,
    vocab_size: int,
):
    """
    Generate the logit mask for structured output.

    Draft tokens may or may not be valid under a request's grammar, so for
    every request that carries a grammar we walk its draft tree (DFS) to
    decide which tokens the grammar accepts and what logit mask applies at
    each tree position.

    Returns the CPU bitmask tensor (or None if no request has a grammar) and
    records the last seen grammar on `verify_input`.
    """
    num_draft_tokens = draft_tokens_cpu.shape[-1]
    assert len(reqs) == retrieve_next_token_cpu.shape[0]

    bitmask = None
    last_grammar = None
    for idx, req in enumerate(reqs):
        if req.grammar is None:
            continue
        if bitmask is None:
            # Lazily allocate one bitmask row per draft token across the batch.
            bitmask = req.grammar.allocate_vocab_mask(
                vocab_size=vocab_size,
                batch_size=draft_tokens_cpu.numel(),
                device="cpu",
            )
        last_grammar = req.grammar
        start_ts = time.perf_counter()
        row_slice = slice(idx * num_draft_tokens, (idx + 1) * num_draft_tokens)
        traverse_tree(
            retrieve_next_token_cpu[idx],
            retrieve_next_sibling_cpu[idx],
            draft_tokens_cpu[idx],
            req.grammar,
            bitmask[row_slice],
        )
        tree_traverse_time = time.perf_counter() - start_ts
        if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
            logger.warning(
                f"Bit mask generation took {tree_traverse_time} seconds with "
                f"grammar: {req.grammar}"
            )

    verify_input.grammar = last_grammar
    return bitmask

View File

@@ -30,7 +30,8 @@ from sglang.srt.model_executor.forward_batch_info import (
)
from sglang.srt.operations import execute_operations, execute_overlapped_operations
from sglang.srt.operations_strategy import OperationsStrategy
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpecInput
from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
if TYPE_CHECKING:
@@ -48,7 +49,7 @@ logger = logging.getLogger(__name__)
def get_token_num_per_seq(
forward_mode: ForwardMode,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
spec_info: Optional[SpecInput] = None,
):
if forward_mode.is_target_verify():
return spec_info.draft_token_num
@@ -273,7 +274,7 @@ def compute_split_token_index(
def compute_split_indices_for_cuda_graph_replay(
forward_mode: ForwardMode,
cuda_graph_num_tokens: int,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
):
forward_mode_for_tbo_split = (
forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
@@ -333,7 +334,7 @@ class TboCudaGraphRunnerPlugin:
forward_mode: ForwardMode,
bs: int,
num_token_non_padded: int,
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
spec_info: Optional[SpecInput],
):
token_num_per_seq = get_token_num_per_seq(
forward_mode=forward_mode, spec_info=spec_info