diff --git a/python/sglang/srt/constrained/outlines_jump_forward.py b/python/sglang/srt/constrained/outlines_jump_forward.py index cfc65f75f..8e19742c6 100644 --- a/python/sglang/srt/constrained/outlines_jump_forward.py +++ b/python/sglang/srt/constrained/outlines_jump_forward.py @@ -37,7 +37,7 @@ except ImportError: IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)" -# Env var was set in sglang.srt.server_args.ServerArgs.__post__init__ +# Env var was set in sglang.srt.server_args.ServerArgs.__post_init__ DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true") logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py index e2ae55780..277c84f9d 100644 --- a/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py +++ b/python/sglang/srt/disaggregation/decode_schedule_batch_mixin.py @@ -157,7 +157,7 @@ class ScheduleBatchDisaggregationDecodeMixin: hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device) # local import to avoid circular import - from sglang.srt.speculative.eagle_utils import EagleDraftInput + from sglang.srt.speculative.eagle_info import EagleDraftInput spec_info = EagleDraftInput( topk_p=topk_p, diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py index f1b2da5f8..30901805d 100644 --- a/python/sglang/srt/layers/attention/aiter_backend.py +++ b/python/sglang/srt/layers/attention/aiter_backend.py @@ -4,18 +4,13 @@ from __future__ import annotations end to end attention solution with aiter kernels """ -import math -import os from dataclasses import dataclass from enum import Enum, auto -from functools import partial -from typing import TYPE_CHECKING, List, Optional, Union +from typing import TYPE_CHECKING, Optional import torch import triton -import triton.language as tl -from sglang.global_config import global_config from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.dp_attention import ( @@ -27,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.spec_info import SpecInfo + from sglang.srt.speculative.spec_info import SpecInput try: from aiter import ( @@ -374,7 +369,7 @@ class AiterAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], ): if forward_mode.is_decode_or_idle(): qo_indptr = None @@ -509,7 +504,7 @@ class AiterAttnBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): if forward_mode.is_decode_or_idle(): @@ -888,7 +883,7 @@ class AiterIndicesUpdaterPrefill: seq_lens_sum: int, prefix_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], ): # Keep the signature for type checking. It will be assigned during runtime. raise NotImplementedError() @@ -900,7 +895,7 @@ class AiterIndicesUpdaterPrefill: seq_lens_sum: int, prefix_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], ): kv_start_idx = None @@ -984,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill: extend_lens: torch.Tensor, max_q_len: int, max_kv_len: int, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], ): # Keep the signature for type checking. It will be assigned during runtime. raise NotImplementedError() @@ -997,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill: extend_lens: torch.Tensor, max_q_len: int, max_kv_len: int, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], ): bs = len(req_pool_indices) @@ -1054,7 +1049,7 @@ class AiterMultiStepDraftBackend: topk: int, speculative_num_steps: int, ): - from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices + from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices self.topk = topk self.speculative_num_steps = speculative_num_steps diff --git a/python/sglang/srt/layers/attention/ascend_backend.py b/python/sglang/srt/layers/attention/ascend_backend.py index 52192b7bc..2391e8664 100644 --- a/python/sglang/srt/layers/attention/ascend_backend.py +++ b/python/sglang/srt/layers/attention/ascend_backend.py @@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, List, Optional import torch import torch_npu -from torch.nn.functional import scaled_dot_product_attention from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.base_attn_backend import AttentionBackend @@ -13,7 +12,8 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.radix_attention import AttentionType -from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import get_bool_env_var if TYPE_CHECKING: @@ -127,7 +127,7 @@ class AscendAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], ): metadata = ForwardMetadata() @@ -147,7 +147,7 @@ class AscendAttnBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): metadata = self.graph_metadata[bs] diff --git a/python/sglang/srt/layers/attention/base_attn_backend.py b/python/sglang/srt/layers/attention/base_attn_backend.py index 3025d0b11..5b0377fcd 100644 --- a/python/sglang/srt/layers/attention/base_attn_backend.py +++ b/python/sglang/srt/layers/attention/base_attn_backend.py @@ -8,7 +8,7 @@ import torch if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode - from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput + from sglang.srt.speculative.spec_info import SpecInput class AttentionBackend(ABC): @@ -31,7 +31,7 @@ class AttentionBackend(ABC): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], ): """Init the metadata for a forward pass for capturing a cuda graph.""" raise NotImplementedError() @@ -44,7 +44,7 @@ class AttentionBackend(ABC): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): """Init the metadata for a forward pass for replaying a cuda graph.""" diff --git a/python/sglang/srt/layers/attention/cutlass_mla_backend.py b/python/sglang/srt/layers/attention/cutlass_mla_backend.py index eb0cae262..e81e761bc 100644 --- a/python/sglang/srt/layers/attention/cutlass_mla_backend.py +++ b/python/sglang/srt/layers/attention/cutlass_mla_backend.py @@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.spec_info import SpecInfo + from sglang.srt.speculative.spec_info import SpecInput _is_cuda = is_cuda() if _is_cuda: @@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], ): if forward_mode.is_decode_or_idle(): if spec_info is None: @@ -190,7 +190,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 67cad8d23..1deb9033c 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -11,9 +11,8 @@ import triton.language as tl from sglang.srt.configs.model_config import AttentionArch from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.managers.schedule_batch import global_server_args_dict -from sglang.srt.mem_cache.memory_pool import SWAKVPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput +from sglang.srt.speculative.spec_info import SpecInput if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention @@ -1487,7 +1486,7 @@ class FlashAttentionBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], ): """Initialize forward metadata for capturing CUDA graph.""" metadata = FlashAttentionMetadata() @@ -1722,7 +1721,7 @@ class FlashAttentionBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], out_cache_loc: Optional[torch.Tensor] = None, ): @@ -2340,7 +2339,7 @@ class FlashAttentionMultiStepBackend: forward_batch: ForwardBatch, ): assert forward_batch.spec_info is not None - assert isinstance(forward_batch.spec_info, EagleDraftInput) + assert forward_batch.spec_info.is_draft_input() for i in range(self.speculative_num_steps - 1): self.attn_backends[i].init_forward_metadata_capture_cuda_graph( @@ -2357,7 +2356,7 @@ class FlashAttentionMultiStepBackend: self, forward_batch: ForwardBatch, bs: int ): assert forward_batch.spec_info is not None - assert isinstance(forward_batch.spec_info, EagleDraftInput) + assert forward_batch.spec_info.is_draft_input() for i in range(self.speculative_num_steps - 1): # TODO: incrementally update the metadata for the later steps, diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py index 2b69d734c..5f2b946f3 100644 --- a/python/sglang/srt/layers/attention/flashinfer_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_backend.py @@ -28,8 +28,8 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.layers.radix_attention import AttentionType from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput -from sglang.srt.speculative.ngram_utils import NgramVerifyInput +from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput +from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import ( get_int_env_var, is_flashinfer_available, @@ -344,7 +344,7 @@ class FlashInferAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], + spec_info: Optional[SpecInput], ): if forward_mode.is_decode_or_idle(): decode_wrappers = [] @@ -451,7 +451,7 @@ class FlashInferAttnBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): if forward_mode.is_decode_or_idle(): @@ -669,7 +669,7 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], + spec_info: Optional[SpecInput], fixed_split_size: Optional[int] = None, disable_split_kv: Optional[bool] = None, ): @@ -684,7 +684,7 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], + spec_info: Optional[SpecInput], fixed_split_size: Optional[int] = None, disable_split_kv: Optional[bool] = None, ): @@ -710,7 +710,7 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], + spec_info: Optional[SpecInput], fixed_split_size: Optional[int] = None, disable_split_kv: Optional[bool] = None, ): @@ -760,7 +760,7 @@ class FlashInferIndicesUpdaterDecode: seq_lens_sum: int, decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper], encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], + spec_info: Optional[SpecInput], fixed_split_size: Optional[int] = None, disable_split_kv: Optional[bool] = None, ): @@ -794,7 +794,7 @@ class FlashInferIndicesUpdaterDecode: paged_kernel_lens_sum: int, kv_indptr: torch.Tensor, kv_start_idx: torch.Tensor, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], use_sliding_window_kv_pool: bool = False, fixed_split_size: Optional[int] = None, @@ -905,7 +905,7 @@ class FlashInferIndicesUpdaterPrefill: prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], + spec_info: Optional[SpecInput], fixed_split_size: Optional[int] = None, ): # Keep the signature for type checking. It will be assigned during runtime. @@ -921,7 +921,7 @@ class FlashInferIndicesUpdaterPrefill: prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], + spec_info: Optional[SpecInput], fixed_split_size: Optional[int] = None, ): if use_ragged: @@ -959,7 +959,7 @@ class FlashInferIndicesUpdaterPrefill: prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], + spec_info: Optional[SpecInput], fixed_split_size: Optional[int] = None, ): for wrapper_id in range(2): @@ -1006,7 +1006,7 @@ class FlashInferIndicesUpdaterPrefill: prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper], use_ragged: bool, encoder_lens: Optional[torch.Tensor], - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], + spec_info: Optional[SpecInput], fixed_split_size: Optional[int] = None, ): for wrapper_id in range(2): @@ -1049,7 +1049,7 @@ class FlashInferIndicesUpdaterPrefill: kv_indptr: torch.Tensor, qo_indptr: torch.Tensor, use_ragged: bool, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]], + spec_info: Optional[SpecInput], use_sliding_window_kv_pool: bool = False, fixed_split_size: Optional[int] = None, ): @@ -1077,9 +1077,7 @@ class FlashInferIndicesUpdaterPrefill: qo_indptr = qo_indptr[: bs + 1] custom_mask = None else: - assert isinstance( - spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput) - ) + assert isinstance(spec_info, SpecInput) kv_indices, kv_indptr, qo_indptr, custom_mask = ( spec_info.generate_attn_arg_prefill( req_pool_indices, @@ -1138,7 +1136,7 @@ class FlashInferMultiStepDraftBackend: topk: int, speculative_num_steps: int, ): - from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices + from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices self.topk = topk self.speculative_num_steps = speculative_num_steps @@ -1202,7 +1200,7 @@ class FlashInferMultiStepDraftBackend: ) assert forward_batch.spec_info is not None - assert isinstance(forward_batch.spec_info, EagleDraftInput) + assert forward_batch.spec_info.is_draft_input() # Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan. indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu() diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index 05e9bef80..e785b6013 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -30,7 +30,7 @@ from sglang.srt.layers.attention.flashinfer_backend import ( from sglang.srt.layers.dp_attention import get_attention_tp_size from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode -from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput +from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import ( is_flashinfer_available, is_sm100_supported, @@ -40,7 +40,7 @@ from sglang.srt.utils import ( if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.spec_info import SpecInfo + from sglang.srt.speculative.spec_info import SpecInput if is_flashinfer_available(): from flashinfer import ( @@ -361,7 +361,7 @@ class FlashInferMLAAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], ): if forward_mode.is_decode_or_idle(): decode_wrapper = BatchMLAPagedAttentionWrapper( @@ -441,7 +441,7 @@ class FlashInferMLAAttnBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): if forward_mode.is_decode_or_idle(): @@ -663,7 +663,7 @@ class FlashInferMLAIndicesUpdaterDecode: seq_lens_sum: int, decode_wrapper: BatchMLAPagedAttentionWrapper, init_metadata_replay: bool = False, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, + spec_info: Optional[SpecInput] = None, **fast_decode_kwargs, ): decode_wrapper = decode_wrapper or self.decode_wrapper @@ -688,7 +688,7 @@ class FlashInferMLAIndicesUpdaterDecode: q_indptr: torch.Tensor, kv_indptr: torch.Tensor, init_metadata_replay: bool = False, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, + spec_info: Optional[SpecInput] = None, **fast_decode_kwargs, ): bs = len(req_pool_indices) @@ -776,7 +776,7 @@ class FlashInferMLAIndicesUpdaterPrefill: prefix_lens: torch.Tensor, prefill_wrapper_paged: BatchMLAPagedAttentionWrapper, use_ragged: bool, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, + spec_info: Optional[SpecInput] = None, ): if use_ragged: paged_kernel_lens = prefix_lens @@ -811,7 +811,7 @@ class FlashInferMLAIndicesUpdaterPrefill: kv_indptr: torch.Tensor, qo_indptr: torch.Tensor, use_ragged: bool, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, + spec_info: Optional[SpecInput] = None, ): bs = len(seq_lens) sm_scale = self.scaling @@ -838,9 +838,7 @@ class FlashInferMLAIndicesUpdaterPrefill: qo_indptr = qo_indptr[: bs + 1] custom_mask = None else: - assert isinstance(spec_info, EagleDraftInput) or isinstance( - spec_info, EagleVerifyInput - ) + assert isinstance(spec_info, SpecInput) # TODO: Support topk > 1 with custom mask kv_indices, kv_indptr, qo_indptr, custom_mask = ( spec_info.generate_attn_arg_prefill( @@ -894,7 +892,7 @@ class FlashInferMLAMultiStepDraftBackend: topk: int, speculative_num_steps: int, ): - from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices + from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices if topk > 1: raise ValueError( @@ -963,7 +961,7 @@ class FlashInferMLAMultiStepDraftBackend: ) assert forward_batch.spec_info is not None - assert isinstance(forward_batch.spec_info, EagleDraftInput) + assert forward_batch.spec_info.is_draft_input() for i in range(self.speculative_num_steps - 1): forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1] @@ -983,8 +981,6 @@ class FlashInferMLAMultiStepDraftBackend: ) def call_fn(i, forward_batch): - assert forward_batch.spec_info is not None - assert isinstance(forward_batch.spec_info, EagleDraftInput) forward_batch.spec_info.kv_indptr = ( forward_batch.spec_info.kv_indptr.clone() ) diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py index d1acb1a58..134380f12 100644 --- a/python/sglang/srt/layers/attention/flashmla_backend.py +++ b/python/sglang/srt/layers/attention/flashmla_backend.py @@ -19,7 +19,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.spec_info import SpecInfo + from sglang.srt.speculative.spec_info import SpecInput # FlashMLA only supports pagesize=64 @@ -187,7 +187,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], ): if forward_mode.is_decode_or_idle(): max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE) @@ -257,7 +257,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): diff --git a/python/sglang/srt/layers/attention/hybrid_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_attn_backend.py index ec40100d1..f7f2c2193 100644 --- a/python/sglang/srt/layers/attention/hybrid_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_attn_backend.py @@ -6,7 +6,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput +from sglang.srt.speculative.spec_info import SpecInput class HybridAttnBackend(AttentionBackend): @@ -71,7 +71,7 @@ class HybridAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], ): backend = self._select_backend(forward_mode) backend.init_forward_metadata_capture_cuda_graph( @@ -92,7 +92,7 @@ class HybridAttnBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): backend = self._select_backend(forward_mode) diff --git a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py index 86dd77f37..435844f74 100644 --- a/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py +++ b/python/sglang/srt/layers/attention/hybrid_linear_attn_backend.py @@ -21,8 +21,8 @@ from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode from sglang.srt.model_executor.model_runner import ModelRunner -from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating -from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput +from sglang.srt.models.qwen3_next import fused_gdn_gating +from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import is_cuda, is_npu if is_cuda(): @@ -134,7 +134,7 @@ class MambaAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], ): if forward_mode.is_decode_or_idle(): self.query_start_loc_list[bs - 1].copy_( @@ -161,7 +161,7 @@ class MambaAttnBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): num_padding = torch.count_nonzero( @@ -451,7 +451,7 @@ class HybridLinearAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], ): for attn_backend in self.attn_backend_list: attn_backend.init_forward_metadata_capture_cuda_graph( @@ -472,7 +472,7 @@ class HybridLinearAttnBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): for attn_backend in self.attn_backend_list: diff --git a/python/sglang/srt/layers/attention/tbo_backend.py b/python/sglang/srt/layers/attention/tbo_backend.py index 06cfbd4ef..bdecfb380 100644 --- a/python/sglang/srt/layers/attention/tbo_backend.py +++ b/python/sglang/srt/layers/attention/tbo_backend.py @@ -1,10 +1,10 @@ -from typing import TYPE_CHECKING, Callable, List, Optional, Union +from typing import TYPE_CHECKING, Callable, List, Optional import torch from sglang.srt import two_batch_overlap from sglang.srt.layers.attention.base_attn_backend import AttentionBackend -from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput +from sglang.srt.speculative.spec_info import SpecInput if TYPE_CHECKING: from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode @@ -46,7 +46,7 @@ class TboAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: "ForwardMode", - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], ): self.primary.init_forward_metadata_capture_cuda_graph( bs=bs, @@ -77,7 +77,7 @@ class TboAttnBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: "ForwardMode", - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): self.primary.init_forward_metadata_replay_cuda_graph( @@ -112,7 +112,7 @@ class TboAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: "ForwardMode", - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], # capture args capture_num_tokens: int = None, # replay args @@ -196,7 +196,7 @@ def _init_forward_metadata_cuda_graph_split( seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: "ForwardMode", - spec_info: Optional[EagleVerifyInput], + spec_info: Optional[SpecInput], # capture args capture_num_tokens: int = None, # replay args diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 55b5c6e54..70e99c31f 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -22,7 +22,7 @@ from sglang.srt.utils import ( if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput + from sglang.srt.speculative.spec_info import SpecInput def logit_capping_mod(logit_capping_method, logit_cap): @@ -482,7 +482,7 @@ class TritonAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], ): assert encoder_lens is None, "Not supported" window_kv_indptr = self.window_kv_indptr @@ -638,7 +638,7 @@ class TritonAttnBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): # NOTE: encoder_lens expected to be zeros or None @@ -883,7 +883,7 @@ class TritonMultiStepDraftBackend: topk: int, speculative_num_steps: int, ): - from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices + from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices self.topk = topk self.speculative_num_steps = speculative_num_steps diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py index a48cc9794..454a388f9 100644 --- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py @@ -20,12 +20,10 @@ from sglang.srt.utils import is_flashinfer_available if is_flashinfer_available(): import flashinfer -from sglang.srt.speculative.eagle_utils import EagleDraftInput - if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.spec_info import SpecInfo + from sglang.srt.speculative.spec_info import SpecInput # Constants DEFAULT_WORKSPACE_SIZE_MB = ( @@ -201,7 +199,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], ): """Initialize metadata for CUDA graph capture.""" metadata = TRTLLMMHAMetadata() @@ -314,7 +312,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): """Replay CUDA graph with new inputs.""" @@ -661,7 +659,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend): forward_batch: ForwardBatch, ): assert forward_batch.spec_info is not None - assert isinstance(forward_batch.spec_info, EagleDraftInput) + assert forward_batch.spec_info.is_draft_input() for i in range(self.speculative_num_steps - 1): self.attn_backends[i].init_forward_metadata_capture_cuda_graph( @@ -678,7 +676,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend): self, forward_batch: ForwardBatch, bs: int ): assert forward_batch.spec_info is not None - assert isinstance(forward_batch.spec_info, EagleDraftInput) + assert forward_batch.spec_info.is_draft_input() for i in range(self.speculative_num_steps - 1): diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py index 7a3f31128..97dce19fd 100755 --- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py +++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py @@ -30,7 +30,7 @@ if is_flashinfer_available(): if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.spec_info import SpecInfo + from sglang.srt.speculative.spec_info import SpecInput _is_cuda = is_cuda() @@ -214,7 +214,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], ): """Initialize metadata for CUDA graph capture.""" @@ -270,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[SpecInfo], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): """Replay CUDA graph with new inputs.""" diff --git a/python/sglang/srt/layers/attention/wave_backend.py b/python/sglang/srt/layers/attention/wave_backend.py index eb6e061ac..9669a4568 100644 --- a/python/sglang/srt/layers/attention/wave_backend.py +++ b/python/sglang/srt/layers/attention/wave_backend.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging from dataclasses import dataclass -from typing import TYPE_CHECKING, Optional, Union +from typing import TYPE_CHECKING, Optional import torch import triton @@ -17,7 +17,7 @@ from sglang.srt.utils import get_bool_env_var, get_device_core_count if TYPE_CHECKING: from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.model_executor.model_runner import ModelRunner - from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput + from sglang.srt.speculative.spec_info import SpecInput logger = logging.getLogger(__name__) @@ -393,7 +393,7 @@ class WaveAttnBackend(AttentionBackend): seq_lens: torch.Tensor, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], ): assert encoder_lens is None, "Not supported" @@ -477,7 +477,7 @@ class WaveAttnBackend(AttentionBackend): seq_lens_sum: int, encoder_lens: Optional[torch.Tensor], forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], seq_lens_cpu: Optional[torch.Tensor], ): # NOTE: encoder_lens expected to be zeros or None diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py index acddcc652..27d41184e 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/layer.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/layer.py @@ -11,12 +11,8 @@ from sglang.srt.distributed import ( get_moe_expert_parallel_world_size, get_moe_tensor_parallel_rank, get_moe_tensor_parallel_world_size, - get_tp_group, tensor_model_parallel_all_reduce, ) -from sglang.srt.distributed.device_communicators.pynccl_allocator import ( - use_symmetric_memory, -) from sglang.srt.eplb.expert_location import get_global_expert_location_metadata from sglang.srt.layers.moe import ( MoeRunnerConfig, @@ -24,7 +20,6 @@ from sglang.srt.layers.moe import ( should_use_flashinfer_trtllm_moe, ) from sglang.srt.layers.moe.token_dispatcher.standard import ( - CombineInput, StandardDispatcher, StandardDispatchOutput, ) diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index ff9edc58b..adabae9d7 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -73,9 +73,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton if TYPE_CHECKING: from sglang.srt.configs.model_config import ModelConfig - from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput - from sglang.srt.speculative.ngram_utils import NgramVerifyInput - from sglang.srt.speculative.spec_info import SpeculativeAlgorithm + from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5 @@ -957,9 +955,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Speculative decoding spec_algorithm: SpeculativeAlgorithm = None - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = ( - None - ) + # spec_info: Optional[SpecInput] = None + spec_info: Optional[SpecInput] = None # Whether to return hidden states return_hidden_states: bool = False @@ -1995,9 +1992,9 @@ class ModelWorkerBatch: # Speculative decoding spec_algorithm: SpeculativeAlgorithm = None - spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = ( - None - ) + + spec_info: Optional[SpecInput] = None + # If set, the output of the batch contains the hidden states of the run. capture_hidden_mode: CaptureHiddenMode = None hicache_consumer_index: int = -1 diff --git a/python/sglang/srt/model_executor/cpu_graph_runner.py b/python/sglang/srt/model_executor/cpu_graph_runner.py index bc1e5c5b8..f1f7aa7b0 100644 --- a/python/sglang/srt/model_executor/cpu_graph_runner.py +++ b/python/sglang/srt/model_executor/cpu_graph_runner.py @@ -607,7 +607,7 @@ class CPUGraphRunner: def get_spec_info(self, num_tokens: int): spec_info = None if self.model_runner.spec_algorithm.is_eagle(): - from sglang.srt.speculative.eagle_utils import EagleVerifyInput + from sglang.srt.speculative.eagle_info import EagleVerifyInput if self.model_runner.is_draft_worker: raise RuntimeError("This should not happen.") diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 2ed78ea58..864ade8b3 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -821,7 +821,7 @@ class CudaGraphRunner: self.model_runner.spec_algorithm.is_eagle() or self.model_runner.spec_algorithm.is_standalone() ): - from sglang.srt.speculative.eagle_utils import EagleVerifyInput + from sglang.srt.speculative.eagle_info import EagleVerifyInput if self.model_runner.is_draft_worker: raise RuntimeError("This should not happen.") diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index 017b5863c..52e96016d 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import ( get_attention_tp_size, set_dp_buffer_len, ) -from sglang.srt.layers.rotary_embedding import MRotaryEmbedding -from sglang.srt.utils import ( - flatten_nested_list, - get_compiler_backend, - is_npu, - support_triton, -) +from sglang.srt.utils import get_compiler_backend, is_npu, support_triton if TYPE_CHECKING: from sglang.srt.layers.attention.base_attn_backend import AttentionBackend @@ -60,8 +54,7 @@ if TYPE_CHECKING: from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo - from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput - from sglang.srt.speculative.spec_info import SpeculativeAlgorithm + from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm _is_npu = is_npu() @@ -293,7 +286,7 @@ class ForwardBatch: global_forward_mode: Optional[ForwardMode] = None # Speculative decoding - spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None + spec_info: Optional[SpecInput] = None spec_algorithm: SpeculativeAlgorithm = None capture_hidden_mode: CaptureHiddenMode = None @@ -364,33 +357,14 @@ class ForwardBatch: # For MLP sync if batch.global_num_tokens is not None: - from sglang.srt.speculative.eagle_utils import ( - EagleDraftInput, - EagleVerifyInput, - ) - assert batch.global_num_tokens_for_logprob is not None + # process global_num_tokens and global_num_tokens_for_logprob if batch.spec_info is not None: - if isinstance(batch.spec_info, EagleDraftInput): - global_num_tokens = [ - x * batch.spec_info.num_tokens_per_batch - for x in batch.global_num_tokens - ] - global_num_tokens_for_logprob = [ - x * batch.spec_info.num_tokens_for_logprob_per_batch - for x in batch.global_num_tokens_for_logprob - ] - else: - assert isinstance(batch.spec_info, EagleVerifyInput) - global_num_tokens = [ - x * batch.spec_info.draft_token_num - for x in batch.global_num_tokens - ] - global_num_tokens_for_logprob = [ - x * batch.spec_info.draft_token_num - for x in batch.global_num_tokens_for_logprob - ] + spec_info: SpecInput = batch.spec_info + global_num_tokens, global_num_tokens_for_logprob = ( + spec_info.get_spec_adjusted_global_num_tokens(batch) + ) else: global_num_tokens = batch.global_num_tokens global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob @@ -669,9 +643,6 @@ class ForwardBatch: ) def prepare_mlp_sync_batch(self, model_runner: ModelRunner): - - from sglang.srt.speculative.eagle_utils import EagleDraftInput - assert self.global_num_tokens_cpu is not None assert self.global_num_tokens_for_logprob_cpu is not None @@ -768,7 +739,8 @@ class ForwardBatch: if self.extend_seq_lens is not None: self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs) - if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput): + if self.spec_info is not None and self.spec_info.is_draft_input(): + # FIXME(lsyin): remove this isinstance logic spec_info = self.spec_info self.output_cache_loc_backup = self.out_cache_loc self.hidden_states_backup = spec_info.hidden_states diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 66d2d5a34..e03626988 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -20,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, ) -from sglang.srt.speculative.eagle_utils import EagleDraftInput +from sglang.srt.speculative.eagle_info import EagleDraftInput from sglang.srt.utils import ( require_attn_tp_gather, require_gathered_buffer, diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index 8340b0ca8..edb37db27 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -21,7 +21,8 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, ) -from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk +from sglang.srt.speculative.eagle_info import EagleDraftInput +from sglang.srt.speculative.spec_utils import fast_topk from sglang.srt.utils import ( require_attn_tp_gather, require_gathered_buffer, diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_info.py similarity index 58% rename from python/sglang/srt/speculative/eagle_utils.py rename to python/sglang/srt/speculative/eagle_info.py index 03270b48f..18a787256 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_info.py @@ -1,235 +1,52 @@ -from __future__ import annotations - -import copy import logging -import os -import time +from copy import copy from dataclasses import dataclass -from typing import List, Optional +from typing import List, Optional, Tuple import torch import torch.nn.functional as F -import triton -import triton.language as tl from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject -from sglang.srt.environ import envs from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.managers.schedule_batch import ( - Req, ScheduleBatch, get_last_loc, global_server_args_dict, ) from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode +from sglang.srt.speculative.spec_info import SpecInput, SpecInputType +from sglang.srt.speculative.spec_utils import ( + SIMULATE_ACC_LEN, + TREE_SPEC_KERNEL_AVAILABLE, + _generate_simulated_accept_index, + align_evict_mask_to_page_size, + assign_req_to_token_pool, + create_accept_length_filter, + create_extend_after_decode_spec_info, + filter_finished_cache_loc_kernel, + get_src_tgt_cache_loc, + get_target_cache_loc, +) from sglang.srt.utils import is_cuda, is_hip, next_power_of_2 if is_cuda(): from sgl_kernel import ( - fast_topk, top_k_renorm_prob, top_p_renorm_prob, tree_speculative_sampling_target_only, verify_tree_greedy, ) elif is_hip(): - from sgl_kernel import fast_topk, verify_tree_greedy - + from sgl_kernel import verify_tree_greedy logger = logging.getLogger(__name__) -# Simulate acceptance length for benchmarking purposes -SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0 -SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get() - -TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly - -TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals() - - @dataclass -class EagleDraftInput: - # The inputs for decode - # shape: (b, topk) - topk_p: torch.Tensor = None - topk_index: torch.Tensor = None - # shape: (b, hidden_size) - hidden_states: torch.Tensor = None - capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL - - # Inputs for extend - # shape: (b,) - verified_id: torch.Tensor = None - accept_length: torch.Tensor = None - accept_length_cpu: List[int] = None - - # Inputs for the attention backends - # shape: (b + 1,) - kv_indptr: torch.Tensor = None - kv_indices: torch.Tensor = None - - # Shape info for padding - num_tokens_per_batch: int = -1 - num_tokens_for_logprob_per_batch: int = -1 - - # Inputs for draft extend - # shape: (b,) - seq_lens_for_draft_extend: torch.Tensor = None - req_pool_indices_for_draft_extend: torch.Tensor = None - - def prepare_for_extend(self, batch: ScheduleBatch): - - if batch.forward_mode.is_idle(): - return - - # Prefill only generate 1 token. - assert len(self.verified_id) == len(batch.seq_lens) - - pt = 0 - for i, extend_len in enumerate(batch.extend_lens): - input_ids = batch.input_ids[pt : pt + extend_len] - batch.input_ids[pt : pt + extend_len] = torch.cat( - (input_ids[1:], self.verified_id[i].reshape(1)) - ) - pt += extend_len - - @classmethod - def create_idle_input( - cls, - device: torch.device, - hidden_size: int, - dtype: torch.dtype, - topk: int, - capture_hidden_mode: CaptureHiddenMode, - ): - return cls( - verified_id=torch.empty((0,), device=device, dtype=torch.int32), - hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype), - topk_p=torch.empty((0, topk), device=device, dtype=torch.float32), - topk_index=torch.empty((0, topk), device=device, dtype=torch.int64), - capture_hidden_mode=capture_hidden_mode, - accept_length=torch.empty((0,), device=device, dtype=torch.int32), - accept_length_cpu=[], - ) - - def prepare_extend_after_decode( - self, - batch: ScheduleBatch, - speculative_num_steps: int, - ): - - if batch.forward_mode.is_idle(): - return - - batch.input_ids = self.verified_id - batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu] - batch.extend_num_tokens = sum(batch.extend_lens) - batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend - batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend - batch.return_logprob = False - batch.return_hidden_states = False - - self.capture_hidden_mode = CaptureHiddenMode.LAST - self.accept_length.add_(1) - self.positions = torch.empty_like(batch.input_ids, dtype=torch.long) - self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32) - - create_extend_after_decode_spec_info[(len(batch.seq_lens),)]( - batch.input_ids, - batch.seq_lens, - self.accept_length, - self.positions, - self.verified_id, - next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))), - ) - - def generate_attn_arg_prefill( - self, - req_pool_indices: torch.Tensor, - paged_kernel_lens: torch.Tensor, - paged_kernel_lens_sum: int, - req_to_token: torch.Tensor, - ): - bs = self.accept_length.numel() - qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") - qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) - cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") - cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) - - if paged_kernel_lens_sum is None: - paged_kernel_lens_sum = cum_kv_seq_len[-1] - - kv_indices = torch.empty( - paged_kernel_lens_sum, dtype=torch.int32, device="cuda" - ) - - create_flashinfer_kv_indices_triton[(bs,)]( - req_to_token, - req_pool_indices, - paged_kernel_lens, - cum_kv_seq_len, - None, - kv_indices, - req_to_token.size(1), - ) - return kv_indices, cum_kv_seq_len, qo_indptr, None - - def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): - if has_been_filtered: - # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index` - # therefore, we don't need to filter the batch again in scheduler - if len(new_indices) != len(self.topk_p): - logger.warning( - f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen" - ) - self.topk_p = self.topk_p[: len(new_indices)] - self.topk_index = self.topk_index[: len(new_indices)] - self.hidden_states = self.hidden_states[: len(new_indices)] - self.verified_id = self.verified_id[: len(new_indices)] - else: - # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index` - self.topk_p = self.topk_p[new_indices] - self.topk_index = self.topk_index[new_indices] - self.hidden_states = self.hidden_states[new_indices] - self.verified_id = self.verified_id[new_indices] - - def merge_batch(self, spec_info: EagleDraftInput): - if self.hidden_states is None: - self.hidden_states = spec_info.hidden_states - self.verified_id = spec_info.verified_id - self.topk_p = spec_info.topk_p - self.topk_index = spec_info.topk_index - return - if spec_info.hidden_states is None: - return - self.hidden_states = torch.cat( - [self.hidden_states, spec_info.hidden_states], axis=0 - ) - self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) - self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) - self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) - - -@dataclass -class EagleVerifyOutput: - # Draft input batch - draft_input: EagleDraftInput - # Logit outputs from target worker - logits_output: LogitsProcessorOutput - # Accepted token ids including the bonus token - verified_id: torch.Tensor - # Accepted token length per sequence in a batch in CPU. - accept_length_per_req_cpu: List[int] - # Accepted indices from logits_output.next_token_logits - accepted_indices: torch.Tensor - - -@dataclass -class EagleVerifyInput: +class EagleVerifyInput(SpecInput): draft_token: torch.Tensor custom_mask: torch.Tensor positions: torch.Tensor @@ -245,6 +62,12 @@ class EagleVerifyInput: seq_lens_cpu: torch.Tensor grammar: BaseGrammarObject = None + def __post_init__(self): + super().__init__(SpecInputType.EAGLE_VERIFY) + + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + return self.draft_token_num, self.draft_token_num + @classmethod def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int): return cls( @@ -724,574 +547,184 @@ class EagleVerifyInput: ) -@triton.jit -def create_extend_after_decode_spec_info( - verified_id, - seq_lens, - accept_lens, - positions, - new_verified_id, - bs_upper: tl.constexpr, -): - pid = tl.program_id(axis=0) - offsets = tl.arange(0, bs_upper) - seq_length = tl.load(seq_lens + pid) - accept_length = tl.load(accept_lens + pid) +@dataclass +class EagleDraftInput(SpecInput): + # The inputs for decode + # shape: (b, topk) + topk_p: torch.Tensor = None + topk_index: torch.Tensor = None + # shape: (b, hidden_size) + hidden_states: torch.Tensor = None + capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL - accept_len_cumsum = tl.sum( - tl.load(accept_lens + offsets, mask=offsets < pid, other=0) - ) - positions_ptr = positions + accept_len_cumsum - mask = offsets < accept_length - tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask) + # Inputs for extend + # shape: (b,) + verified_id: torch.Tensor = None + accept_length: torch.Tensor = None + accept_length_cpu: List[int] = None - accept_len_cumsum += accept_length - 1 - verified_id_data = tl.load(verified_id + accept_len_cumsum) - tl.store(new_verified_id + pid, verified_id_data) + # Inputs for the attention backends + # shape: (b + 1,) + kv_indptr: torch.Tensor = None + kv_indices: torch.Tensor = None + # Shape info for padding + num_tokens_per_batch: int = -1 + num_tokens_for_logprob_per_batch: int = -1 -@triton.jit -def assign_req_to_token_pool( - req_pool_indices, - req_to_token, - start_offset, - end_offset, - out_cache_loc, - pool_len: tl.constexpr, - bs_upper: tl.constexpr, -): - BLOCK_SIZE: tl.constexpr = 32 - pid = tl.program_id(axis=0) - kv_start = tl.load(start_offset + pid) - kv_end = tl.load(end_offset + pid) - token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len + # Inputs for draft extend + # shape: (b,) + seq_lens_for_draft_extend: torch.Tensor = None + req_pool_indices_for_draft_extend: torch.Tensor = None - length_offset = tl.arange(0, bs_upper) - start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0) - end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0) - out_offset = tl.sum(end - start, axis=0) + def __post_init__(self): + super().__init__(SpecInputType.EAGLE_DRAFT) - out_cache_ptr = out_cache_loc + out_offset + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + return self.num_tokens_per_batch, self.num_tokens_for_logprob_per_batch - save_offset = tl.arange(0, BLOCK_SIZE) + kv_start - load_offset = tl.arange(0, BLOCK_SIZE) + def prepare_for_extend(self, batch: ScheduleBatch): - num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) - for _ in range(num_loop): - mask = save_offset < kv_end - data = tl.load(out_cache_ptr + load_offset, mask=mask) - tl.store(token_pool + save_offset, data, mask=mask) - save_offset += BLOCK_SIZE - load_offset += BLOCK_SIZE + if batch.forward_mode.is_idle(): + return + # Prefill only generate 1 token. + assert len(self.verified_id) == len(batch.seq_lens) -@triton.jit -def assign_draft_cache_locs( - req_pool_indices, - req_to_token, - seq_lens, - extend_lens, - num_new_pages_per_topk, - out_cache_loc, - pool_len: tl.constexpr, - topk: tl.constexpr, - speculative_num_steps: tl.constexpr, - page_size: tl.constexpr, - bs_upper: tl.constexpr, - iter_upper: tl.constexpr, -): - BLOCK_SIZE: tl.constexpr = 128 - pid = tl.program_id(axis=0) + pt = 0 + for i, extend_len in enumerate(batch.extend_lens): + input_ids = batch.input_ids[pt : pt + extend_len] + batch.input_ids[pt : pt + extend_len] = torch.cat( + (input_ids[1:], self.verified_id[i].reshape(1)) + ) + pt += extend_len - if page_size == 1 or topk == 1: - copy_len = topk * speculative_num_steps - out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps - else: - bs_offset = tl.arange(0, bs_upper) - copy_len = tl.load(extend_lens + pid) - cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid)) - out_cache_ptr = out_cache_loc + cum_copy_len - - # Part 1: Copy from out_cache_loc to req_to_token - kv_start = tl.load(seq_lens + pid) - token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len - num_loop = tl.cdiv(copy_len, BLOCK_SIZE) - for i in range(num_loop): - copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE - mask = copy_offset < copy_len - data = tl.load(out_cache_ptr + copy_offset, mask=mask) - tl.store(token_pool + kv_start + copy_offset, data, mask=mask) - - if page_size == 1 or topk == 1: - return - - # Part 2: Copy the indices for the last partial page - prefix_len = tl.load(seq_lens + pid) - last_page_len = prefix_len % page_size - offsets = tl.arange(0, page_size) - mask = offsets < last_page_len - num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid) - prefix_base = token_pool + prefix_len - last_page_len - - for topk_id in range(topk): - value = tl.load(prefix_base + offsets, mask=mask) - tl.store( - prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets, - value, - mask=mask, - ) - - # Part 3: Remove the padding in out_cache_loc - iter_offest = tl.arange(0, iter_upper) - for topk_id in range(topk): - indices = tl.load( - prefix_base - + topk_id * num_new_pages_per_topk_ * page_size - + last_page_len - + iter_offest, - mask=iter_offest < speculative_num_steps, - ) - tl.store( - out_cache_loc - + pid * topk * speculative_num_steps - + topk_id * speculative_num_steps - + iter_offest, - indices, - mask=iter_offest < speculative_num_steps, - ) - - -@triton.jit -def generate_draft_decode_kv_indices( - req_pool_indices, - req_to_token, - paged_kernel_lens, - kv_indices, - kv_indptr, - positions, - pool_len: tl.constexpr, - kv_indices_stride: tl.constexpr, - kv_indptr_stride: tl.constexpr, - bs_upper: tl.constexpr, - iter_upper: tl.constexpr, - num_tokens_upper: tl.constexpr, - page_size: tl.constexpr, -): - BLOCK_SIZE: tl.constexpr = 128 - iters = tl.program_id(axis=0) - bid = tl.program_id(axis=1) - topk_id = tl.program_id(axis=2) - - num_steps = tl.num_programs(axis=0) - num_seqs = tl.num_programs(axis=1) - topk = tl.num_programs(axis=2) - - kv_indices += kv_indices_stride * iters - kv_indptr += kv_indptr_stride * iters - iters += 1 - - load_offset = tl.arange(0, bs_upper) - seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0) - seq_len = tl.load(paged_kernel_lens + bid) - cum_seq_len = tl.sum(seq_lens) - - # Update kv_indices - kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters) - kv_ptr = kv_indices + kv_offset - token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len - - kv_offset = tl.arange(0, BLOCK_SIZE) - num_loop = tl.cdiv(seq_len, BLOCK_SIZE) - for _ in range(num_loop): - mask = kv_offset < seq_len - data = tl.load(token_pool_ptr + kv_offset, mask=mask) - tl.store(kv_ptr + kv_offset, data, mask=mask) - kv_offset += BLOCK_SIZE - - extend_offset = tl.arange(0, iter_upper) - if page_size == 1 or topk == 1: - extend_data = tl.load( - token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper), - mask=extend_offset < iters, - ) - else: - prefix_len = seq_len - last_page_len = prefix_len % page_size - num_new_pages_per_topk = ( - last_page_len + num_steps + page_size - 1 - ) // page_size - prefix_base = seq_len // page_size * page_size - start = ( - prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len - ) - extend_data = tl.load( - token_pool_ptr + start + extend_offset, - mask=extend_offset < iters, - ) - - tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters) - - # Update kv_indptr - bs_offset = tl.arange(0, num_tokens_upper) - - zid = bid * topk + topk_id - if zid == 0: - zid = num_seqs * topk - positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0) - base = tl.sum(positions) - tl.store(kv_indptr + zid, base + zid * iters) - - -@triton.jit -def align_evict_mask_to_page_size( - seq_lens, - evict_mask, - page_size: tl.constexpr, - num_draft_tokens: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -): - t_range = tl.arange(0, BLOCK_SIZE) - - bid = tl.program_id(axis=0) - seq_len = tl.load(seq_lens + bid) - io_mask = t_range < num_draft_tokens - mask_row = tl.load( - evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0 - ) - - num_trues = tl.sum(mask_row) - num_false = num_draft_tokens - num_trues - - start = (seq_len + num_false - 1) // page_size * page_size - seq_len - for i in range(max(start, 0), min(start + page_size, num_draft_tokens)): - tl.store(evict_mask + bid * num_draft_tokens + i, False) - - -@triton.jit -def get_target_cache_loc( - tgt_cache_loc, - to_free_slots, - accept_length, - to_free_num_slots, - out_cache_loc, - num_verify_tokens: tl.constexpr, - num_verify_tokens_upper: tl.constexpr, - bs_upper: tl.constexpr, -): - bid = tl.program_id(axis=0) - offset = tl.arange(0, num_verify_tokens_upper) - bs_offset = tl.arange(0, bs_upper) - - # write the first part to tgt_cache_loc - accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid) - tgt_cache_loc_start = tl.sum(accept_len_all) + bid - copy_len = tl.load(accept_length + bid) + 1 - out_cache_loc_row = tl.load( - out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len - ) - tl.store( - tgt_cache_loc + tgt_cache_loc_start + offset, - out_cache_loc_row, - mask=offset < copy_len, - ) - - # write the second part to to_free_num_pages - to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid) - to_free_num_slots_cur = tl.load(to_free_num_slots + bid) - out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur - to_free_slots_start = tl.sum(to_free_num_slots_all) - - copy_len = to_free_num_slots_cur - out_cache_loc_row = tl.load( - out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset, - mask=offset < copy_len, - ) - tl.store( - to_free_slots + to_free_slots_start + offset, - out_cache_loc_row, - mask=offset < copy_len, - ) - - -@torch.compile(dynamic=True) -def get_src_tgt_cache_loc( - seq_lens: torch.Tensor, - out_cache_loc: torch.Tensor, - accept_index: torch.Tensor, - accept_length: torch.Tensor, - draft_token_num: int, - page_size: int, -): - src_cache_loc = out_cache_loc[accept_index] - tgt_cache_loc = torch.empty_like(src_cache_loc) - extended_len = seq_lens + draft_token_num - keep_len = torch.minimum( - (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size, - extended_len, - ) - to_free_num_slots = extended_len - keep_len - return src_cache_loc, tgt_cache_loc, to_free_num_slots - - -@triton.jit -def filter_finished_cache_loc_kernel( - out_cache_loc, - tgt_cache_loc, - accept_length, - accept_length_filter, - bs_upper: tl.constexpr, - num_verify_tokens_upper: tl.constexpr, -): - bid = tl.program_id(0) - bs_offset = tl.arange(0, bs_upper) - - accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid) - old_start = tl.sum(accept_length_all) + bid - - accept_length_filter_all = tl.load( - accept_length_filter + bs_offset, mask=bs_offset < bid - ) - new_start = tl.sum(accept_length_filter_all) - - copy_len = tl.load(accept_length_filter + bid) - copy_offset = tl.arange(0, num_verify_tokens_upper) - value = tl.load( - tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len - ) - tl.store( - out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len - ) - - -@torch.compile(dynamic=True) -def create_accept_length_filter( - accept_length: torch.Tensor, - unfinished_index_device: torch.Tensor, - seq_lens: torch.Tensor, -): - accept_length_filter = torch.zeros_like(accept_length) - accept_length_filter[unfinished_index_device] = ( - accept_length[unfinished_index_device] + 1 - ) - seq_lens.add_(accept_length + 1) - return accept_length_filter - - -@torch.compile(dynamic=True) -def select_top_k_tokens( - i: int, - topk_p: torch.Tensor, - topk_index: torch.Tensor, - hidden_states: torch.Tensor, - scores: torch.Tensor, - topk: int, -): - if i == 0: - # The first step after extend - input_ids = topk_index.flatten() - hidden_states = hidden_states.repeat_interleave(topk, dim=0) - scores = topk_p # shape: (b, topk) - - tree_info = ( - topk_p.unsqueeze(1), # shape: (b, 1, topk) - topk_index, # shape: (b, topk) - torch.arange(-1, topk, dtype=torch.long, device="cuda") - .unsqueeze(0) - .repeat(topk_p.shape[0], 1), # shape: (b, topk + 1) - ) - else: - # The later decode steps - expand_scores = torch.mul( - scores.unsqueeze(2), topk_p.reshape(-1, topk, topk) - ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) - topk_cs_p, topk_cs_index = fast_topk( - expand_scores.flatten(start_dim=1), topk, dim=-1 - ) # (b, topk) - scores = topk_cs_p # shape: (b, topk) - - topk_index = topk_index.reshape(-1, topk**2) - input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten() - - if hidden_states.shape[0] > 0: - selected_input_index = topk_cs_index.flatten() // topk + torch.arange( - 0, hidden_states.shape[0], step=topk, device="cuda" - ).repeat_interleave(topk) - hidden_states = hidden_states[selected_input_index, :] - - tree_info = ( - expand_scores, # shape: (b, topk, topk) - topk_index, # shape: (b, topk * topk) - topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk) - ) - - return input_ids, hidden_states, scores, tree_info - - -def _generate_simulated_accept_index( - accept_index, - predict, - accept_length, - bs, - spec_steps, - simulate_acc_len: float = SIMULATE_ACC_LEN, - simulate_acc_method: str = SIMULATE_ACC_METHOD, -): - assert simulate_acc_len > 0.0 - - if simulate_acc_method == "multinomial": - simulated_values = torch.normal( - mean=simulate_acc_len, - std=1.0, - size=(1,), - device="cpu", - ) - # clamp simulated values to be between 1 and self.spec_steps - simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1) - simulate_acc_len = int(simulated_values.round().item()) - elif simulate_acc_method == "match-expected": - # multinomial sampling does not match the expected length - # we keep it for the sake of compatibility of existing tests - # but it's better to use "match-expected" for the cases that need to - # match the expected length, One caveat is that this will only sample - # either round down or round up of the expected length - simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len)) - lower = int(simulate_acc_len // 1) - upper = lower + 1 if lower < spec_steps + 1 else lower - if lower == upper: - simulate_acc_len = lower - else: - weight_upper = simulate_acc_len - lower - weight_lower = 1.0 - weight_upper - probs = torch.tensor([weight_lower, weight_upper], device="cpu") - sampled_index = torch.multinomial(probs, num_samples=1) - simulate_acc_len = lower if sampled_index == 0 else upper - else: - raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}") - - accept_indx_first_col = accept_index[:, 0].view(-1, 1) - sim_accept_index = torch.full( - (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda" - ) - sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange( - simulate_acc_len, device=accept_index.device - ) - accept_length.fill_(simulate_acc_len - 1) - predict.fill_(100) # some legit token id - return sim_accept_index - - -def traverse_tree( - retrieve_next_token: torch.Tensor, - retrieve_next_sibling: torch.Tensor, - draft_tokens: torch.Tensor, - grammar: BaseGrammarObject, - allocate_token_bitmask: torch.Tensor, -): - """ - Traverse the tree constructed by the draft model to generate the logits mask. - """ - assert ( - retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape - ) - - allocate_token_bitmask.fill_(0) - - def dfs( - curr: int, - retrieve_next_token: torch.Tensor, - retrieve_next_sibling: torch.Tensor, - parent_pos: int, + @classmethod + def create_idle_input( + cls, + device: torch.device, + hidden_size: int, + dtype: torch.dtype, + topk: int, + capture_hidden_mode: CaptureHiddenMode, ): - if curr == 0: - # the first token generated by the target model, and thus it is always - # accepted from the previous iteration - accepted = True - else: - parent_bitmask = allocate_token_bitmask[parent_pos] - curr_token_id = draft_tokens[curr] - # 32 boolean bitmask values are packed into 32-bit integers - accepted = ( - parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32)) - ) != 0 + return cls( + verified_id=torch.empty((0,), device=device, dtype=torch.int32), + hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype), + topk_p=torch.empty((0, topk), device=device, dtype=torch.float32), + topk_index=torch.empty((0, topk), device=device, dtype=torch.int64), + capture_hidden_mode=capture_hidden_mode, + accept_length=torch.empty((0,), device=device, dtype=torch.int32), + accept_length_cpu=[], + ) - if accepted: - if curr != 0: - # Accept the current token - grammar.accept_token(draft_tokens[curr]) - if not grammar.is_terminated(): - # Generate the bitmask for the current token - grammar.fill_vocab_mask(allocate_token_bitmask, curr) - if retrieve_next_token[curr] != -1: - # Visit the child node - dfs( - retrieve_next_token[curr], - retrieve_next_token, - retrieve_next_sibling, - curr, - ) + def prepare_extend_after_decode( + self, + batch: ScheduleBatch, + speculative_num_steps: int, + ): - if curr != 0: - # Rollback the current token - grammar.rollback(1) + if batch.forward_mode.is_idle(): + return - if retrieve_next_sibling[curr] != -1: - # Visit the sibling node - dfs( - retrieve_next_sibling[curr], - retrieve_next_token, - retrieve_next_sibling, - parent_pos, - ) + batch.input_ids = self.verified_id + batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu] + batch.extend_num_tokens = sum(batch.extend_lens) + batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend + batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend + batch.return_logprob = False + batch.return_hidden_states = False - dfs(0, retrieve_next_token, retrieve_next_sibling, -1) + self.capture_hidden_mode = CaptureHiddenMode.LAST + self.accept_length.add_(1) + self.positions = torch.empty_like(batch.input_ids, dtype=torch.long) + self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32) + create_extend_after_decode_spec_info[(len(batch.seq_lens),)]( + batch.input_ids, + batch.seq_lens, + self.accept_length, + self.positions, + self.verified_id, + next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))), + ) -def generate_token_bitmask( - reqs: List[Req], - verify_input: EagleVerifyInput, - retrieve_next_token_cpu: torch.Tensor, - retrieve_next_sibling_cpu: torch.Tensor, - draft_tokens_cpu: torch.Tensor, - vocab_size: int, -): - """ - Generate the logit mask for structured output. - Draft model's token can be either valid or invalid with respect to the grammar. - We need to perform DFS to - 1. figure out which tokens are accepted by the grammar. - 2. if so, what is the corresponding logit mask. - """ + def generate_attn_arg_prefill( + self, + req_pool_indices: torch.Tensor, + paged_kernel_lens: torch.Tensor, + paged_kernel_lens_sum: int, + req_to_token: torch.Tensor, + ): + bs = self.accept_length.numel() + qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") + qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0) + cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda") + cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0) - num_draft_tokens = draft_tokens_cpu.shape[-1] + if paged_kernel_lens_sum is None: + paged_kernel_lens_sum = cum_kv_seq_len[-1] - allocate_token_bitmask = None - assert len(reqs) == retrieve_next_token_cpu.shape[0] - grammar = None - for i, req in enumerate(reqs): - if req.grammar is not None: - if allocate_token_bitmask is None: - allocate_token_bitmask = req.grammar.allocate_vocab_mask( - vocab_size=vocab_size, - batch_size=draft_tokens_cpu.numel(), - device="cpu", - ) - grammar = req.grammar - s = time.perf_counter() - traverse_tree( - retrieve_next_token_cpu[i], - retrieve_next_sibling_cpu[i], - draft_tokens_cpu[i], - req.grammar, - allocate_token_bitmask[ - i * num_draft_tokens : (i + 1) * num_draft_tokens - ], - ) - tree_traverse_time = time.perf_counter() - s - if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD: + kv_indices = torch.empty( + paged_kernel_lens_sum, dtype=torch.int32, device="cuda" + ) + + create_flashinfer_kv_indices_triton[(bs,)]( + req_to_token, + req_pool_indices, + paged_kernel_lens, + cum_kv_seq_len, + None, + kv_indices, + req_to_token.size(1), + ) + return kv_indices, cum_kv_seq_len, qo_indptr, None + + def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True): + if has_been_filtered: + # in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index` + # therefore, we don't need to filter the batch again in scheduler + if len(new_indices) != len(self.topk_p): logger.warning( - f"Bit mask generation took {tree_traverse_time} seconds with " - f"grammar: {req.grammar}" + f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen" ) + self.topk_p = self.topk_p[: len(new_indices)] + self.topk_index = self.topk_index[: len(new_indices)] + self.hidden_states = self.hidden_states[: len(new_indices)] + self.verified_id = self.verified_id[: len(new_indices)] + else: + # in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index` + self.topk_p = self.topk_p[new_indices] + self.topk_index = self.topk_index[new_indices] + self.hidden_states = self.hidden_states[new_indices] + self.verified_id = self.verified_id[new_indices] - verify_input.grammar = grammar - return allocate_token_bitmask + def merge_batch(self, spec_info: "EagleDraftInput"): + if self.hidden_states is None: + self.hidden_states = spec_info.hidden_states + self.verified_id = spec_info.verified_id + self.topk_p = spec_info.topk_p + self.topk_index = spec_info.topk_index + return + if spec_info.hidden_states is None: + return + self.hidden_states = torch.cat( + [self.hidden_states, spec_info.hidden_states], axis=0 + ) + self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0) + self.topk_p = torch.cat([self.topk_p, spec_info.topk_p]) + self.topk_index = torch.cat([self.topk_index, spec_info.topk_index]) + + +@dataclass +class EagleVerifyOutput: + # Draft input batch + draft_input: EagleDraftInput + # Logit outputs from target worker + logits_output: LogitsProcessorOutput + # Accepted token ids including the bonus token + verified_id: torch.Tensor + # Accepted token length per sequence in a batch in CPU. + accept_length_per_req_cpu: List[int] + # Accepted indices from logits_output.next_token_logits + accepted_indices: torch.Tensor diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index abc13da9d..1782d6da0 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -34,16 +34,18 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import ( from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import ( EAGLEDraftExtendCudaGraphRunner, ) -from sglang.srt.speculative.eagle_utils import ( +from sglang.srt.speculative.eagle_info import ( EagleDraftInput, EagleVerifyInput, EagleVerifyOutput, +) +from sglang.srt.speculative.spec_info import SpeculativeAlgorithm +from sglang.srt.speculative.spec_utils import ( assign_draft_cache_locs, fast_topk, generate_token_bitmask, select_top_k_tokens, ) -from sglang.srt.speculative.spec_info import SpeculativeAlgorithm from sglang.srt.utils import ( empty_context, get_available_gpu_memory, diff --git a/python/sglang/srt/speculative/ngram_utils.py b/python/sglang/srt/speculative/ngram_utils.py index d0e80c0a4..ad4a332bd 100644 --- a/python/sglang/srt/speculative/ngram_utils.py +++ b/python/sglang/srt/speculative/ngram_utils.py @@ -2,7 +2,7 @@ from __future__ import annotations import copy import logging -from typing import Optional +from typing import Optional, Tuple import torch import triton @@ -13,6 +13,7 @@ from dataclasses import dataclass import torch.nn.functional as F +from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import apply_custom_logit_processor from sglang.srt.managers.schedule_batch import ( @@ -21,10 +22,10 @@ from sglang.srt.managers.schedule_batch import ( global_server_args_dict, ) from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo -from sglang.srt.speculative.eagle_utils import ( +from sglang.srt.speculative.spec_info import SpecInput, SpecInputType +from sglang.srt.speculative.spec_utils import ( TREE_SPEC_KERNEL_AVAILABLE, assign_req_to_token_pool, - create_flashinfer_kv_indices_triton, get_src_tgt_cache_loc, get_target_cache_loc, ) @@ -42,7 +43,7 @@ elif is_hip(): @dataclass -class NgramVerifyInput: +class NgramVerifyInput(SpecInput): def __init__( self, draft_token: torch.Tensor, @@ -53,6 +54,7 @@ class NgramVerifyInput: retrive_next_sibling: torch.Tensor, draft_token_num: int, ): + super().__init__(SpecInputType.NGRAM_VERIFY) self.draft_token = draft_token self.custom_mask = tree_mask self.positions = positions @@ -62,6 +64,9 @@ class NgramVerifyInput: self.draft_token_num = draft_token_num self.device = self.custom_mask.device + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + return self.draft_token_num, self.draft_token_num + def prepare_for_verify(self, batch: ScheduleBatch, page_size: int): if batch.forward_mode.is_idle(): return diff --git a/python/sglang/srt/speculative/ngram_worker.py b/python/sglang/srt/speculative/ngram_worker.py index cb0155911..69dc83b1f 100644 --- a/python/sglang/srt/speculative/ngram_worker.py +++ b/python/sglang/srt/speculative/ngram_worker.py @@ -1,8 +1,5 @@ import logging -import os -import threading -import time -from typing import TYPE_CHECKING, List, Optional, Union +from typing import List, Optional import numpy as np import torch @@ -15,7 +12,6 @@ from sglang.srt.server_args import ServerArgs from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache from sglang.srt.speculative.ngram_utils import NgramVerifyInput from sglang.srt.speculative.spec_info import SpeculativeAlgorithm -from sglang.srt.utils import broadcast_pyobj logger = logging.getLogger(__name__) diff --git a/python/sglang/srt/speculative/spec_info.py b/python/sglang/srt/speculative/spec_info.py index 64a02f19e..389d57ed1 100644 --- a/python/sglang/srt/speculative/spec_info.py +++ b/python/sglang/srt/speculative/spec_info.py @@ -1,4 +1,8 @@ +from abc import ABC, abstractmethod from enum import IntEnum, auto +from typing import List, Tuple + +from sglang.srt.managers.schedule_batch import ModelWorkerBatch class SpeculativeAlgorithm(IntEnum): @@ -35,3 +39,41 @@ class SpeculativeAlgorithm(IntEnum): if name is not None: name = name.upper() return name_map[name] + + +class SpecInputType(IntEnum): + # NOTE: introduce this to distinguish the SpecInput types of multiple algorithms when asserting in attention backends. + # If all algorithms can share the same datastrucutre of draft_input and verify_input, consider simplify it + EAGLE_DRAFT = auto() + EAGLE_VERIFY = auto() + NGRAM_VERIFY = auto() + + +class SpecInput(ABC): + def __init__(self, spec_input_type: SpecInputType): + self.spec_input_type = spec_input_type + + def is_draft_input(self) -> bool: + # FIXME: remove this function which is only used for assertion + # or use another variable name like `draft_input` to substitute `spec_info` + return self.spec_input_type == SpecInputType.EAGLE_DRAFT + + def is_verify_input(self) -> bool: + return self.spec_input_type in { + SpecInputType.EAGLE_VERIFY, + SpecInputType.NGRAM_VERIFY, + } + + @abstractmethod + def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]: + pass + + def get_spec_adjusted_global_num_tokens( + self, forward_batch: ModelWorkerBatch + ) -> Tuple[List[int], List[int]]: + c1, c2 = self.get_spec_adjust_token_coefficient() + global_num_tokens = [x * c1 for x in forward_batch.global_num_tokens] + global_num_tokens_for_logprob = [ + x * c2 for x in forward_batch.global_num_tokens_for_logprob + ] + return global_num_tokens, global_num_tokens_for_logprob diff --git a/python/sglang/srt/speculative/spec_utils.py b/python/sglang/srt/speculative/spec_utils.py new file mode 100644 index 000000000..6640c077d --- /dev/null +++ b/python/sglang/srt/speculative/spec_utils.py @@ -0,0 +1,607 @@ +from __future__ import annotations + +import logging +import os +import time +from typing import TYPE_CHECKING, List + +import torch +import triton +import triton.language as tl + +from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject +from sglang.srt.environ import envs +from sglang.srt.managers.schedule_batch import Req +from sglang.srt.utils import is_cuda, is_hip + +if is_cuda(): + from sgl_kernel import fast_topk +elif is_hip(): + from sgl_kernel import fast_topk + +if TYPE_CHECKING: + from sglang.srt.speculative.eagle_info import EagleVerifyInput + +logger = logging.getLogger(__name__) + + +# Simulate acceptance length for benchmarking purposes +SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0 +SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get() + +TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly + +TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals() + + +@triton.jit +def create_extend_after_decode_spec_info( + verified_id, + seq_lens, + accept_lens, + positions, + new_verified_id, + bs_upper: tl.constexpr, +): + pid = tl.program_id(axis=0) + offsets = tl.arange(0, bs_upper) + seq_length = tl.load(seq_lens + pid) + accept_length = tl.load(accept_lens + pid) + + accept_len_cumsum = tl.sum( + tl.load(accept_lens + offsets, mask=offsets < pid, other=0) + ) + positions_ptr = positions + accept_len_cumsum + mask = offsets < accept_length + tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask) + + accept_len_cumsum += accept_length - 1 + verified_id_data = tl.load(verified_id + accept_len_cumsum) + tl.store(new_verified_id + pid, verified_id_data) + + +@triton.jit +def assign_req_to_token_pool( + req_pool_indices, + req_to_token, + start_offset, + end_offset, + out_cache_loc, + pool_len: tl.constexpr, + bs_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 32 + pid = tl.program_id(axis=0) + kv_start = tl.load(start_offset + pid) + kv_end = tl.load(end_offset + pid) + token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len + + length_offset = tl.arange(0, bs_upper) + start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0) + end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0) + out_offset = tl.sum(end - start, axis=0) + + out_cache_ptr = out_cache_loc + out_offset + + save_offset = tl.arange(0, BLOCK_SIZE) + kv_start + load_offset = tl.arange(0, BLOCK_SIZE) + + num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE) + for _ in range(num_loop): + mask = save_offset < kv_end + data = tl.load(out_cache_ptr + load_offset, mask=mask) + tl.store(token_pool + save_offset, data, mask=mask) + save_offset += BLOCK_SIZE + load_offset += BLOCK_SIZE + + +@triton.jit +def assign_draft_cache_locs( + req_pool_indices, + req_to_token, + seq_lens, + extend_lens, + num_new_pages_per_topk, + out_cache_loc, + pool_len: tl.constexpr, + topk: tl.constexpr, + speculative_num_steps: tl.constexpr, + page_size: tl.constexpr, + bs_upper: tl.constexpr, + iter_upper: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 128 + pid = tl.program_id(axis=0) + + if page_size == 1 or topk == 1: + copy_len = topk * speculative_num_steps + out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps + else: + bs_offset = tl.arange(0, bs_upper) + copy_len = tl.load(extend_lens + pid) + cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid)) + out_cache_ptr = out_cache_loc + cum_copy_len + + # Part 1: Copy from out_cache_loc to req_to_token + kv_start = tl.load(seq_lens + pid) + token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len + num_loop = tl.cdiv(copy_len, BLOCK_SIZE) + for i in range(num_loop): + copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE + mask = copy_offset < copy_len + data = tl.load(out_cache_ptr + copy_offset, mask=mask) + tl.store(token_pool + kv_start + copy_offset, data, mask=mask) + + if page_size == 1 or topk == 1: + return + + # Part 2: Copy the indices for the last partial page + prefix_len = tl.load(seq_lens + pid) + last_page_len = prefix_len % page_size + offsets = tl.arange(0, page_size) + mask = offsets < last_page_len + num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid) + prefix_base = token_pool + prefix_len - last_page_len + + for topk_id in range(topk): + value = tl.load(prefix_base + offsets, mask=mask) + tl.store( + prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets, + value, + mask=mask, + ) + + # Part 3: Remove the padding in out_cache_loc + iter_offest = tl.arange(0, iter_upper) + for topk_id in range(topk): + indices = tl.load( + prefix_base + + topk_id * num_new_pages_per_topk_ * page_size + + last_page_len + + iter_offest, + mask=iter_offest < speculative_num_steps, + ) + tl.store( + out_cache_loc + + pid * topk * speculative_num_steps + + topk_id * speculative_num_steps + + iter_offest, + indices, + mask=iter_offest < speculative_num_steps, + ) + + +@triton.jit +def generate_draft_decode_kv_indices( + req_pool_indices, + req_to_token, + paged_kernel_lens, + kv_indices, + kv_indptr, + positions, + pool_len: tl.constexpr, + kv_indices_stride: tl.constexpr, + kv_indptr_stride: tl.constexpr, + bs_upper: tl.constexpr, + iter_upper: tl.constexpr, + num_tokens_upper: tl.constexpr, + page_size: tl.constexpr, +): + BLOCK_SIZE: tl.constexpr = 128 + iters = tl.program_id(axis=0) + bid = tl.program_id(axis=1) + topk_id = tl.program_id(axis=2) + + num_steps = tl.num_programs(axis=0) + num_seqs = tl.num_programs(axis=1) + topk = tl.num_programs(axis=2) + + kv_indices += kv_indices_stride * iters + kv_indptr += kv_indptr_stride * iters + iters += 1 + + load_offset = tl.arange(0, bs_upper) + seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0) + seq_len = tl.load(paged_kernel_lens + bid) + cum_seq_len = tl.sum(seq_lens) + + # Update kv_indices + kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters) + kv_ptr = kv_indices + kv_offset + token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len + + kv_offset = tl.arange(0, BLOCK_SIZE) + num_loop = tl.cdiv(seq_len, BLOCK_SIZE) + for _ in range(num_loop): + mask = kv_offset < seq_len + data = tl.load(token_pool_ptr + kv_offset, mask=mask) + tl.store(kv_ptr + kv_offset, data, mask=mask) + kv_offset += BLOCK_SIZE + + extend_offset = tl.arange(0, iter_upper) + if page_size == 1 or topk == 1: + extend_data = tl.load( + token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper), + mask=extend_offset < iters, + ) + else: + prefix_len = seq_len + last_page_len = prefix_len % page_size + num_new_pages_per_topk = ( + last_page_len + num_steps + page_size - 1 + ) // page_size + prefix_base = seq_len // page_size * page_size + start = ( + prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len + ) + extend_data = tl.load( + token_pool_ptr + start + extend_offset, + mask=extend_offset < iters, + ) + + tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters) + + # Update kv_indptr + bs_offset = tl.arange(0, num_tokens_upper) + + zid = bid * topk + topk_id + if zid == 0: + zid = num_seqs * topk + positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0) + base = tl.sum(positions) + tl.store(kv_indptr + zid, base + zid * iters) + + +@triton.jit +def align_evict_mask_to_page_size( + seq_lens, + evict_mask, + page_size: tl.constexpr, + num_draft_tokens: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + t_range = tl.arange(0, BLOCK_SIZE) + + bid = tl.program_id(axis=0) + seq_len = tl.load(seq_lens + bid) + io_mask = t_range < num_draft_tokens + mask_row = tl.load( + evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0 + ) + + num_trues = tl.sum(mask_row) + num_false = num_draft_tokens - num_trues + + start = (seq_len + num_false - 1) // page_size * page_size - seq_len + for i in range(max(start, 0), min(start + page_size, num_draft_tokens)): + tl.store(evict_mask + bid * num_draft_tokens + i, False) + + +@triton.jit +def get_target_cache_loc( + tgt_cache_loc, + to_free_slots, + accept_length, + to_free_num_slots, + out_cache_loc, + num_verify_tokens: tl.constexpr, + num_verify_tokens_upper: tl.constexpr, + bs_upper: tl.constexpr, +): + bid = tl.program_id(axis=0) + offset = tl.arange(0, num_verify_tokens_upper) + bs_offset = tl.arange(0, bs_upper) + + # write the first part to tgt_cache_loc + accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid) + tgt_cache_loc_start = tl.sum(accept_len_all) + bid + copy_len = tl.load(accept_length + bid) + 1 + out_cache_loc_row = tl.load( + out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len + ) + tl.store( + tgt_cache_loc + tgt_cache_loc_start + offset, + out_cache_loc_row, + mask=offset < copy_len, + ) + + # write the second part to to_free_num_pages + to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid) + to_free_num_slots_cur = tl.load(to_free_num_slots + bid) + out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur + to_free_slots_start = tl.sum(to_free_num_slots_all) + + copy_len = to_free_num_slots_cur + out_cache_loc_row = tl.load( + out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset, + mask=offset < copy_len, + ) + tl.store( + to_free_slots + to_free_slots_start + offset, + out_cache_loc_row, + mask=offset < copy_len, + ) + + +@torch.compile(dynamic=True) +def get_src_tgt_cache_loc( + seq_lens: torch.Tensor, + out_cache_loc: torch.Tensor, + accept_index: torch.Tensor, + accept_length: torch.Tensor, + draft_token_num: int, + page_size: int, +): + src_cache_loc = out_cache_loc[accept_index] + tgt_cache_loc = torch.empty_like(src_cache_loc) + extended_len = seq_lens + draft_token_num + keep_len = torch.minimum( + (seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size, + extended_len, + ) + to_free_num_slots = extended_len - keep_len + return src_cache_loc, tgt_cache_loc, to_free_num_slots + + +@triton.jit +def filter_finished_cache_loc_kernel( + out_cache_loc, + tgt_cache_loc, + accept_length, + accept_length_filter, + bs_upper: tl.constexpr, + num_verify_tokens_upper: tl.constexpr, +): + bid = tl.program_id(0) + bs_offset = tl.arange(0, bs_upper) + + accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid) + old_start = tl.sum(accept_length_all) + bid + + accept_length_filter_all = tl.load( + accept_length_filter + bs_offset, mask=bs_offset < bid + ) + new_start = tl.sum(accept_length_filter_all) + + copy_len = tl.load(accept_length_filter + bid) + copy_offset = tl.arange(0, num_verify_tokens_upper) + value = tl.load( + tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len + ) + tl.store( + out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len + ) + + +@torch.compile(dynamic=True) +def create_accept_length_filter( + accept_length: torch.Tensor, + unfinished_index_device: torch.Tensor, + seq_lens: torch.Tensor, +): + accept_length_filter = torch.zeros_like(accept_length) + accept_length_filter[unfinished_index_device] = ( + accept_length[unfinished_index_device] + 1 + ) + seq_lens.add_(accept_length + 1) + return accept_length_filter + + +@torch.compile(dynamic=True) +def select_top_k_tokens( + i: int, + topk_p: torch.Tensor, + topk_index: torch.Tensor, + hidden_states: torch.Tensor, + scores: torch.Tensor, + topk: int, +): + if i == 0: + # The first step after extend + input_ids = topk_index.flatten() + hidden_states = hidden_states.repeat_interleave(topk, dim=0) + scores = topk_p # shape: (b, topk) + + tree_info = ( + topk_p.unsqueeze(1), # shape: (b, 1, topk) + topk_index, # shape: (b, topk) + torch.arange(-1, topk, dtype=torch.long, device="cuda") + .unsqueeze(0) + .repeat(topk_p.shape[0], 1), # shape: (b, topk + 1) + ) + else: + # The later decode steps + expand_scores = torch.mul( + scores.unsqueeze(2), topk_p.reshape(-1, topk, topk) + ) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk) + topk_cs_p, topk_cs_index = fast_topk( + expand_scores.flatten(start_dim=1), topk, dim=-1 + ) # (b, topk) + scores = topk_cs_p # shape: (b, topk) + + topk_index = topk_index.reshape(-1, topk**2) + input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten() + + if hidden_states.shape[0] > 0: + selected_input_index = topk_cs_index.flatten() // topk + torch.arange( + 0, hidden_states.shape[0], step=topk, device="cuda" + ).repeat_interleave(topk) + hidden_states = hidden_states[selected_input_index, :] + + tree_info = ( + expand_scores, # shape: (b, topk, topk) + topk_index, # shape: (b, topk * topk) + topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk) + ) + + return input_ids, hidden_states, scores, tree_info + + +def _generate_simulated_accept_index( + accept_index, + predict, + accept_length, + bs, + spec_steps, + simulate_acc_len: float = SIMULATE_ACC_LEN, + simulate_acc_method: str = SIMULATE_ACC_METHOD, +): + assert simulate_acc_len > 0.0 + + if simulate_acc_method == "multinomial": + simulated_values = torch.normal( + mean=simulate_acc_len, + std=1.0, + size=(1,), + device="cpu", + ) + # clamp simulated values to be between 1 and self.spec_steps + simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1) + simulate_acc_len = int(simulated_values.round().item()) + elif simulate_acc_method == "match-expected": + # multinomial sampling does not match the expected length + # we keep it for the sake of compatibility of existing tests + # but it's better to use "match-expected" for the cases that need to + # match the expected length, One caveat is that this will only sample + # either round down or round up of the expected length + simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len)) + lower = int(simulate_acc_len // 1) + upper = lower + 1 if lower < spec_steps + 1 else lower + if lower == upper: + simulate_acc_len = lower + else: + weight_upper = simulate_acc_len - lower + weight_lower = 1.0 - weight_upper + probs = torch.tensor([weight_lower, weight_upper], device="cpu") + sampled_index = torch.multinomial(probs, num_samples=1) + simulate_acc_len = lower if sampled_index == 0 else upper + else: + raise ValueError(f"Invalid simulate_acc_method: {SIMULATE_ACC_METHOD}") + + accept_indx_first_col = accept_index[:, 0].view(-1, 1) + sim_accept_index = torch.full( + (bs, spec_steps + 1), -1, dtype=torch.int32, device="cuda" + ) + sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange( + simulate_acc_len, device=accept_index.device + ) + accept_length.fill_(simulate_acc_len - 1) + predict.fill_(100) # some legit token id + return sim_accept_index + + +def traverse_tree( + retrieve_next_token: torch.Tensor, + retrieve_next_sibling: torch.Tensor, + draft_tokens: torch.Tensor, + grammar: BaseGrammarObject, + allocate_token_bitmask: torch.Tensor, +): + """ + Traverse the tree constructed by the draft model to generate the logits mask. + """ + assert ( + retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape + ) + + allocate_token_bitmask.fill_(0) + + def dfs( + curr: int, + retrieve_next_token: torch.Tensor, + retrieve_next_sibling: torch.Tensor, + parent_pos: int, + ): + if curr == 0: + # the first token generated by the target model, and thus it is always + # accepted from the previous iteration + accepted = True + else: + parent_bitmask = allocate_token_bitmask[parent_pos] + curr_token_id = draft_tokens[curr] + # 32 boolean bitmask values are packed into 32-bit integers + accepted = ( + parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32)) + ) != 0 + + if accepted: + if curr != 0: + # Accept the current token + grammar.accept_token(draft_tokens[curr]) + if not grammar.is_terminated(): + # Generate the bitmask for the current token + grammar.fill_vocab_mask(allocate_token_bitmask, curr) + if retrieve_next_token[curr] != -1: + # Visit the child node + dfs( + retrieve_next_token[curr], + retrieve_next_token, + retrieve_next_sibling, + curr, + ) + + if curr != 0: + # Rollback the current token + grammar.rollback(1) + + if retrieve_next_sibling[curr] != -1: + # Visit the sibling node + dfs( + retrieve_next_sibling[curr], + retrieve_next_token, + retrieve_next_sibling, + parent_pos, + ) + + dfs(0, retrieve_next_token, retrieve_next_sibling, -1) + + +def generate_token_bitmask( + reqs: List[Req], + verify_input: EagleVerifyInput, + retrieve_next_token_cpu: torch.Tensor, + retrieve_next_sibling_cpu: torch.Tensor, + draft_tokens_cpu: torch.Tensor, + vocab_size: int, +): + """ + Generate the logit mask for structured output. + Draft model's token can be either valid or invalid with respect to the grammar. + We need to perform DFS to + 1. figure out which tokens are accepted by the grammar. + 2. if so, what is the corresponding logit mask. + """ + + num_draft_tokens = draft_tokens_cpu.shape[-1] + + allocate_token_bitmask = None + assert len(reqs) == retrieve_next_token_cpu.shape[0] + grammar = None + for i, req in enumerate(reqs): + if req.grammar is not None: + if allocate_token_bitmask is None: + allocate_token_bitmask = req.grammar.allocate_vocab_mask( + vocab_size=vocab_size, + batch_size=draft_tokens_cpu.numel(), + device="cpu", + ) + grammar = req.grammar + s = time.perf_counter() + traverse_tree( + retrieve_next_token_cpu[i], + retrieve_next_sibling_cpu[i], + draft_tokens_cpu[i], + req.grammar, + allocate_token_bitmask[ + i * num_draft_tokens : (i + 1) * num_draft_tokens + ], + ) + tree_traverse_time = time.perf_counter() - s + if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD: + logger.warning( + f"Bit mask generation took {tree_traverse_time} seconds with " + f"grammar: {req.grammar}" + ) + + verify_input.grammar = grammar + return allocate_token_bitmask diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index 82717b382..d67636aa4 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -30,7 +30,8 @@ from sglang.srt.model_executor.forward_batch_info import ( ) from sglang.srt.operations import execute_operations, execute_overlapped_operations from sglang.srt.operations_strategy import OperationsStrategy -from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput +from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput +from sglang.srt.speculative.spec_info import SpecInput from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip if TYPE_CHECKING: @@ -48,7 +49,7 @@ logger = logging.getLogger(__name__) def get_token_num_per_seq( forward_mode: ForwardMode, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None, + spec_info: Optional[SpecInput] = None, ): if forward_mode.is_target_verify(): return spec_info.draft_token_num @@ -273,7 +274,7 @@ def compute_split_token_index( def compute_split_indices_for_cuda_graph_replay( forward_mode: ForwardMode, cuda_graph_num_tokens: int, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], ): forward_mode_for_tbo_split = ( forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE @@ -333,7 +334,7 @@ class TboCudaGraphRunnerPlugin: forward_mode: ForwardMode, bs: int, num_token_non_padded: int, - spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]], + spec_info: Optional[SpecInput], ): token_num_per_seq = get_token_num_per_seq( forward_mode=forward_mode, spec_info=spec_info diff --git a/test/srt/test_forward_split_prefill.py b/test/srt/test_forward_split_prefill.py index bbd247583..3100c8d00 100644 --- a/test/srt/test_forward_split_prefill.py +++ b/test/srt/test_forward_split_prefill.py @@ -7,7 +7,6 @@ or python3 test_forward_split_prefill.py """ -import time import unittest import numpy as np @@ -16,7 +15,7 @@ import torch from sglang.srt.configs.model_config import ModelConfig from sglang.srt.hf_transformers_utils import get_tokenizer from sglang.srt.managers.schedule_batch import Req, ScheduleBatch -from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode +from sglang.srt.model_executor.forward_batch_info import ForwardBatch from sglang.srt.model_executor.model_runner import ModelRunner from sglang.srt.sampling.sampling_params import SamplingParams from sglang.srt.server_args import PortArgs, ServerArgs