Organize spec-related data structures (#10735)
This commit is contained in:
@@ -37,7 +37,7 @@ except ImportError:
|
||||
|
||||
IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
|
||||
|
||||
# Env var was set in sglang.srt.server_args.ServerArgs.__post__init__
|
||||
# Env var was set in sglang.srt.server_args.ServerArgs.__post_init__
|
||||
DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -157,7 +157,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
||||
hidden_states = torch.stack(hidden_states_list, dim=0).to(self.device)
|
||||
|
||||
# local import to avoid circular import
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
|
||||
spec_info = EagleDraftInput(
|
||||
topk_p=topk_p,
|
||||
|
||||
@@ -4,18 +4,13 @@ from __future__ import annotations
|
||||
end to end attention solution with aiter kernels
|
||||
"""
|
||||
|
||||
import math
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from functools import partial
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.global_config import global_config
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||
from sglang.srt.layers.dp_attention import (
|
||||
@@ -27,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
try:
|
||||
from aiter import (
|
||||
@@ -374,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
qo_indptr = None
|
||||
@@ -509,7 +504,7 @@ class AiterAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
@@ -888,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
|
||||
seq_lens_sum: int,
|
||||
prefix_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
@@ -900,7 +895,7 @@ class AiterIndicesUpdaterPrefill:
|
||||
seq_lens_sum: int,
|
||||
prefix_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
|
||||
kv_start_idx = None
|
||||
@@ -984,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
|
||||
extend_lens: torch.Tensor,
|
||||
max_q_len: int,
|
||||
max_kv_len: int,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
@@ -997,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill:
|
||||
extend_lens: torch.Tensor,
|
||||
max_q_len: int,
|
||||
max_kv_len: int,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
bs = len(req_pool_indices)
|
||||
|
||||
@@ -1054,7 +1049,7 @@ class AiterMultiStepDraftBackend:
|
||||
topk: int,
|
||||
speculative_num_steps: int,
|
||||
):
|
||||
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
|
||||
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
|
||||
|
||||
self.topk = topk
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
from torch.nn.functional import scaled_dot_product_attention
|
||||
|
||||
from sglang.srt.configs.model_config import AttentionArch
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
@@ -13,7 +12,8 @@ from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess
|
||||
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.radix_attention import AttentionType
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
from sglang.srt.utils import get_bool_env_var
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -127,7 +127,7 @@ class AscendAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
metadata = ForwardMetadata()
|
||||
|
||||
@@ -147,7 +147,7 @@ class AscendAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
metadata = self.graph_metadata[bs]
|
||||
|
||||
@@ -8,7 +8,7 @@ import torch
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
|
||||
class AttentionBackend(ABC):
|
||||
@@ -31,7 +31,7 @@ class AttentionBackend(ABC):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
"""Init the metadata for a forward pass for capturing a cuda graph."""
|
||||
raise NotImplementedError()
|
||||
@@ -44,7 +44,7 @@ class AttentionBackend(ABC):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
"""Init the metadata for a forward pass for replaying a cuda graph."""
|
||||
|
||||
@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
if _is_cuda:
|
||||
@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
if spec_info is None:
|
||||
@@ -190,7 +190,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
|
||||
|
||||
@@ -11,9 +11,8 @@ import triton.language as tl
|
||||
from sglang.srt.configs.model_config import AttentionArch
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import SWAKVPool
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
@@ -1487,7 +1486,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
"""Initialize forward metadata for capturing CUDA graph."""
|
||||
metadata = FlashAttentionMetadata()
|
||||
@@ -1722,7 +1721,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
out_cache_loc: Optional[torch.Tensor] = None,
|
||||
):
|
||||
@@ -2340,7 +2339,7 @@ class FlashAttentionMultiStepBackend:
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
assert forward_batch.spec_info is not None
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
assert forward_batch.spec_info.is_draft_input()
|
||||
|
||||
for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
|
||||
@@ -2357,7 +2356,7 @@ class FlashAttentionMultiStepBackend:
|
||||
self, forward_batch: ForwardBatch, bs: int
|
||||
):
|
||||
assert forward_batch.spec_info is not None
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
assert forward_batch.spec_info.is_draft_input()
|
||||
|
||||
for i in range(self.speculative_num_steps - 1):
|
||||
# TODO: incrementally update the metadata for the later steps,
|
||||
|
||||
@@ -28,8 +28,8 @@ from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.layers.radix_attention import AttentionType
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
from sglang.srt.utils import (
|
||||
get_int_env_var,
|
||||
is_flashinfer_available,
|
||||
@@ -344,7 +344,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
decode_wrappers = []
|
||||
@@ -451,7 +451,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
@@ -669,7 +669,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
disable_split_kv: Optional[bool] = None,
|
||||
):
|
||||
@@ -684,7 +684,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
disable_split_kv: Optional[bool] = None,
|
||||
):
|
||||
@@ -710,7 +710,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
disable_split_kv: Optional[bool] = None,
|
||||
):
|
||||
@@ -760,7 +760,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
disable_split_kv: Optional[bool] = None,
|
||||
):
|
||||
@@ -794,7 +794,7 @@ class FlashInferIndicesUpdaterDecode:
|
||||
paged_kernel_lens_sum: int,
|
||||
kv_indptr: torch.Tensor,
|
||||
kv_start_idx: torch.Tensor,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
use_sliding_window_kv_pool: bool = False,
|
||||
fixed_split_size: Optional[int] = None,
|
||||
@@ -905,7 +905,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
):
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
@@ -921,7 +921,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
):
|
||||
if use_ragged:
|
||||
@@ -959,7 +959,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
@@ -1006,7 +1006,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
fixed_split_size: Optional[int] = None,
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
@@ -1049,7 +1049,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
kv_indptr: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
use_ragged: bool,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
use_sliding_window_kv_pool: bool = False,
|
||||
fixed_split_size: Optional[int] = None,
|
||||
):
|
||||
@@ -1077,9 +1077,7 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
qo_indptr = qo_indptr[: bs + 1]
|
||||
custom_mask = None
|
||||
else:
|
||||
assert isinstance(
|
||||
spec_info, (EagleDraftInput, EagleVerifyInput, NgramVerifyInput)
|
||||
)
|
||||
assert isinstance(spec_info, SpecInput)
|
||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||
spec_info.generate_attn_arg_prefill(
|
||||
req_pool_indices,
|
||||
@@ -1138,7 +1136,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
topk: int,
|
||||
speculative_num_steps: int,
|
||||
):
|
||||
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
|
||||
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
|
||||
|
||||
self.topk = topk
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
@@ -1202,7 +1200,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
)
|
||||
|
||||
assert forward_batch.spec_info is not None
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
assert forward_batch.spec_info.is_draft_input()
|
||||
|
||||
# Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
|
||||
indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
|
||||
|
||||
@@ -30,7 +30,7 @@ from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
from sglang.srt.utils import (
|
||||
is_flashinfer_available,
|
||||
is_sm100_supported,
|
||||
@@ -40,7 +40,7 @@ from sglang.srt.utils import (
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
if is_flashinfer_available():
|
||||
from flashinfer import (
|
||||
@@ -361,7 +361,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
decode_wrapper = BatchMLAPagedAttentionWrapper(
|
||||
@@ -441,7 +441,7 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
@@ -663,7 +663,7 @@ class FlashInferMLAIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrapper: BatchMLAPagedAttentionWrapper,
|
||||
init_metadata_replay: bool = False,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
|
||||
spec_info: Optional[SpecInput] = None,
|
||||
**fast_decode_kwargs,
|
||||
):
|
||||
decode_wrapper = decode_wrapper or self.decode_wrapper
|
||||
@@ -688,7 +688,7 @@ class FlashInferMLAIndicesUpdaterDecode:
|
||||
q_indptr: torch.Tensor,
|
||||
kv_indptr: torch.Tensor,
|
||||
init_metadata_replay: bool = False,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
|
||||
spec_info: Optional[SpecInput] = None,
|
||||
**fast_decode_kwargs,
|
||||
):
|
||||
bs = len(req_pool_indices)
|
||||
@@ -776,7 +776,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
||||
prefix_lens: torch.Tensor,
|
||||
prefill_wrapper_paged: BatchMLAPagedAttentionWrapper,
|
||||
use_ragged: bool,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
|
||||
spec_info: Optional[SpecInput] = None,
|
||||
):
|
||||
if use_ragged:
|
||||
paged_kernel_lens = prefix_lens
|
||||
@@ -811,7 +811,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
||||
kv_indptr: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
use_ragged: bool,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
|
||||
spec_info: Optional[SpecInput] = None,
|
||||
):
|
||||
bs = len(seq_lens)
|
||||
sm_scale = self.scaling
|
||||
@@ -838,9 +838,7 @@ class FlashInferMLAIndicesUpdaterPrefill:
|
||||
qo_indptr = qo_indptr[: bs + 1]
|
||||
custom_mask = None
|
||||
else:
|
||||
assert isinstance(spec_info, EagleDraftInput) or isinstance(
|
||||
spec_info, EagleVerifyInput
|
||||
)
|
||||
assert isinstance(spec_info, SpecInput)
|
||||
# TODO: Support topk > 1 with custom mask
|
||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||
spec_info.generate_attn_arg_prefill(
|
||||
@@ -894,7 +892,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
topk: int,
|
||||
speculative_num_steps: int,
|
||||
):
|
||||
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
|
||||
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
|
||||
|
||||
if topk > 1:
|
||||
raise ValueError(
|
||||
@@ -963,7 +961,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
)
|
||||
|
||||
assert forward_batch.spec_info is not None
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
assert forward_batch.spec_info.is_draft_input()
|
||||
|
||||
for i in range(self.speculative_num_steps - 1):
|
||||
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
||||
@@ -983,8 +981,6 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
)
|
||||
|
||||
def call_fn(i, forward_batch):
|
||||
assert forward_batch.spec_info is not None
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
forward_batch.spec_info.kv_indptr = (
|
||||
forward_batch.spec_info.kv_indptr.clone()
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
|
||||
# FlashMLA only supports pagesize=64
|
||||
@@ -187,7 +187,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
|
||||
@@ -257,7 +257,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
|
||||
class HybridAttnBackend(AttentionBackend):
|
||||
@@ -71,7 +71,7 @@ class HybridAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
backend = self._select_backend(forward_mode)
|
||||
backend.init_forward_metadata_capture_cuda_graph(
|
||||
@@ -92,7 +92,7 @@ class HybridAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
backend = self._select_backend(forward_mode)
|
||||
|
||||
@@ -21,8 +21,8 @@ from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.models.qwen3_next import Qwen3HybridLinearDecoderLayer, fused_gdn_gating
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.models.qwen3_next import fused_gdn_gating
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
from sglang.srt.utils import is_cuda, is_npu
|
||||
|
||||
if is_cuda():
|
||||
@@ -134,7 +134,7 @@ class MambaAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
self.query_start_loc_list[bs - 1].copy_(
|
||||
@@ -161,7 +161,7 @@ class MambaAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
num_padding = torch.count_nonzero(
|
||||
@@ -451,7 +451,7 @@ class HybridLinearAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
for attn_backend in self.attn_backend_list:
|
||||
attn_backend.init_forward_metadata_capture_cuda_graph(
|
||||
@@ -472,7 +472,7 @@ class HybridLinearAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
for attn_backend in self.attn_backend_list:
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Callable, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt import two_batch_overlap
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
@@ -46,7 +46,7 @@ class TboAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: "ForwardMode",
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
self.primary.init_forward_metadata_capture_cuda_graph(
|
||||
bs=bs,
|
||||
@@ -77,7 +77,7 @@ class TboAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: "ForwardMode",
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
self.primary.init_forward_metadata_replay_cuda_graph(
|
||||
@@ -112,7 +112,7 @@ class TboAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: "ForwardMode",
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
# capture args
|
||||
capture_num_tokens: int = None,
|
||||
# replay args
|
||||
@@ -196,7 +196,7 @@ def _init_forward_metadata_cuda_graph_split(
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: "ForwardMode",
|
||||
spec_info: Optional[EagleVerifyInput],
|
||||
spec_info: Optional[SpecInput],
|
||||
# capture args
|
||||
capture_num_tokens: int = None,
|
||||
# replay args
|
||||
|
||||
@@ -22,7 +22,7 @@ from sglang.srt.utils import (
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
|
||||
def logit_capping_mod(logit_capping_method, logit_cap):
|
||||
@@ -482,7 +482,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
assert encoder_lens is None, "Not supported"
|
||||
window_kv_indptr = self.window_kv_indptr
|
||||
@@ -638,7 +638,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
# NOTE: encoder_lens expected to be zeros or None
|
||||
@@ -883,7 +883,7 @@ class TritonMultiStepDraftBackend:
|
||||
topk: int,
|
||||
speculative_num_steps: int,
|
||||
):
|
||||
from sglang.srt.speculative.eagle_utils import generate_draft_decode_kv_indices
|
||||
from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
|
||||
|
||||
self.topk = topk
|
||||
self.speculative_num_steps = speculative_num_steps
|
||||
|
||||
@@ -20,12 +20,10 @@ from sglang.srt.utils import is_flashinfer_available
|
||||
if is_flashinfer_available():
|
||||
import flashinfer
|
||||
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
# Constants
|
||||
DEFAULT_WORKSPACE_SIZE_MB = (
|
||||
@@ -201,7 +199,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
"""Initialize metadata for CUDA graph capture."""
|
||||
metadata = TRTLLMMHAMetadata()
|
||||
@@ -314,7 +312,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
"""Replay CUDA graph with new inputs."""
|
||||
@@ -661,7 +659,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
|
||||
forward_batch: ForwardBatch,
|
||||
):
|
||||
assert forward_batch.spec_info is not None
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
assert forward_batch.spec_info.is_draft_input()
|
||||
|
||||
for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
|
||||
@@ -678,7 +676,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
|
||||
self, forward_batch: ForwardBatch, bs: int
|
||||
):
|
||||
assert forward_batch.spec_info is not None
|
||||
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||
assert forward_batch.spec_info.is_draft_input()
|
||||
|
||||
for i in range(self.speculative_num_steps - 1):
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ if is_flashinfer_available():
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.spec_info import SpecInfo
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
|
||||
@@ -214,7 +214,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
"""Initialize metadata for CUDA graph capture."""
|
||||
|
||||
@@ -270,7 +270,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[SpecInfo],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
"""Replay CUDA graph with new inputs."""
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Optional, Union
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -17,7 +17,7 @@ from sglang.srt.utils import get_bool_env_var, get_device_core_count
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -393,7 +393,7 @@ class WaveAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
assert encoder_lens is None, "Not supported"
|
||||
|
||||
@@ -477,7 +477,7 @@ class WaveAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
# NOTE: encoder_lens expected to be zeros or None
|
||||
|
||||
@@ -11,12 +11,8 @@ from sglang.srt.distributed import (
|
||||
get_moe_expert_parallel_world_size,
|
||||
get_moe_tensor_parallel_rank,
|
||||
get_moe_tensor_parallel_world_size,
|
||||
get_tp_group,
|
||||
tensor_model_parallel_all_reduce,
|
||||
)
|
||||
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
|
||||
use_symmetric_memory,
|
||||
)
|
||||
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
|
||||
from sglang.srt.layers.moe import (
|
||||
MoeRunnerConfig,
|
||||
@@ -24,7 +20,6 @@ from sglang.srt.layers.moe import (
|
||||
should_use_flashinfer_trtllm_moe,
|
||||
)
|
||||
from sglang.srt.layers.moe.token_dispatcher.standard import (
|
||||
CombineInput,
|
||||
StandardDispatcher,
|
||||
StandardDispatchOutput,
|
||||
)
|
||||
|
||||
@@ -73,9 +73,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
|
||||
|
||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||
|
||||
@@ -957,9 +955,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
# Speculative decoding
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput, NgramVerifyInput]] = (
|
||||
None
|
||||
)
|
||||
# spec_info: Optional[SpecInput] = None
|
||||
spec_info: Optional[SpecInput] = None
|
||||
|
||||
# Whether to return hidden states
|
||||
return_hidden_states: bool = False
|
||||
@@ -1995,9 +1992,9 @@ class ModelWorkerBatch:
|
||||
|
||||
# Speculative decoding
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput, NgramVerifyInput]] = (
|
||||
None
|
||||
)
|
||||
|
||||
spec_info: Optional[SpecInput] = None
|
||||
|
||||
# If set, the output of the batch contains the hidden states of the run.
|
||||
capture_hidden_mode: CaptureHiddenMode = None
|
||||
hicache_consumer_index: int = -1
|
||||
|
||||
@@ -607,7 +607,7 @@ class CPUGraphRunner:
|
||||
def get_spec_info(self, num_tokens: int):
|
||||
spec_info = None
|
||||
if self.model_runner.spec_algorithm.is_eagle():
|
||||
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
|
||||
from sglang.srt.speculative.eagle_info import EagleVerifyInput
|
||||
|
||||
if self.model_runner.is_draft_worker:
|
||||
raise RuntimeError("This should not happen.")
|
||||
|
||||
@@ -821,7 +821,7 @@ class CudaGraphRunner:
|
||||
self.model_runner.spec_algorithm.is_eagle()
|
||||
or self.model_runner.spec_algorithm.is_standalone()
|
||||
):
|
||||
from sglang.srt.speculative.eagle_utils import EagleVerifyInput
|
||||
from sglang.srt.speculative.eagle_info import EagleVerifyInput
|
||||
|
||||
if self.model_runner.is_draft_worker:
|
||||
raise RuntimeError("This should not happen.")
|
||||
|
||||
@@ -45,13 +45,7 @@ from sglang.srt.layers.dp_attention import (
|
||||
get_attention_tp_size,
|
||||
set_dp_buffer_len,
|
||||
)
|
||||
from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
|
||||
from sglang.srt.utils import (
|
||||
flatten_nested_list,
|
||||
get_compiler_backend,
|
||||
is_npu,
|
||||
support_triton,
|
||||
)
|
||||
from sglang.srt.utils import get_compiler_backend, is_npu, support_triton
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
@@ -60,8 +54,7 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
|
||||
|
||||
_is_npu = is_npu()
|
||||
|
||||
@@ -293,7 +286,7 @@ class ForwardBatch:
|
||||
global_forward_mode: Optional[ForwardMode] = None
|
||||
|
||||
# Speculative decoding
|
||||
spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
|
||||
spec_info: Optional[SpecInput] = None
|
||||
spec_algorithm: SpeculativeAlgorithm = None
|
||||
capture_hidden_mode: CaptureHiddenMode = None
|
||||
|
||||
@@ -364,33 +357,14 @@ class ForwardBatch:
|
||||
|
||||
# For MLP sync
|
||||
if batch.global_num_tokens is not None:
|
||||
from sglang.srt.speculative.eagle_utils import (
|
||||
EagleDraftInput,
|
||||
EagleVerifyInput,
|
||||
)
|
||||
|
||||
assert batch.global_num_tokens_for_logprob is not None
|
||||
|
||||
# process global_num_tokens and global_num_tokens_for_logprob
|
||||
if batch.spec_info is not None:
|
||||
if isinstance(batch.spec_info, EagleDraftInput):
|
||||
global_num_tokens = [
|
||||
x * batch.spec_info.num_tokens_per_batch
|
||||
for x in batch.global_num_tokens
|
||||
]
|
||||
global_num_tokens_for_logprob = [
|
||||
x * batch.spec_info.num_tokens_for_logprob_per_batch
|
||||
for x in batch.global_num_tokens_for_logprob
|
||||
]
|
||||
else:
|
||||
assert isinstance(batch.spec_info, EagleVerifyInput)
|
||||
global_num_tokens = [
|
||||
x * batch.spec_info.draft_token_num
|
||||
for x in batch.global_num_tokens
|
||||
]
|
||||
global_num_tokens_for_logprob = [
|
||||
x * batch.spec_info.draft_token_num
|
||||
for x in batch.global_num_tokens_for_logprob
|
||||
]
|
||||
spec_info: SpecInput = batch.spec_info
|
||||
global_num_tokens, global_num_tokens_for_logprob = (
|
||||
spec_info.get_spec_adjusted_global_num_tokens(batch)
|
||||
)
|
||||
else:
|
||||
global_num_tokens = batch.global_num_tokens
|
||||
global_num_tokens_for_logprob = batch.global_num_tokens_for_logprob
|
||||
@@ -669,9 +643,6 @@ class ForwardBatch:
|
||||
)
|
||||
|
||||
def prepare_mlp_sync_batch(self, model_runner: ModelRunner):
|
||||
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||
|
||||
assert self.global_num_tokens_cpu is not None
|
||||
assert self.global_num_tokens_for_logprob_cpu is not None
|
||||
|
||||
@@ -768,7 +739,8 @@ class ForwardBatch:
|
||||
if self.extend_seq_lens is not None:
|
||||
self.extend_seq_lens = self._pad_tensor_to_size(self.extend_seq_lens, bs)
|
||||
|
||||
if self.spec_info is not None and isinstance(self.spec_info, EagleDraftInput):
|
||||
if self.spec_info is not None and self.spec_info.is_draft_input():
|
||||
# FIXME(lsyin): remove this isinstance logic
|
||||
spec_info = self.spec_info
|
||||
self.output_cache_loc_backup = self.out_cache_loc
|
||||
self.hidden_states_backup = spec_info.hidden_states
|
||||
|
||||
@@ -20,7 +20,7 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
from sglang.srt.utils import (
|
||||
require_attn_tp_gather,
|
||||
require_gathered_buffer,
|
||||
|
||||
@@ -21,7 +21,8 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
ForwardBatch,
|
||||
ForwardMode,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput
|
||||
from sglang.srt.speculative.spec_utils import fast_topk
|
||||
from sglang.srt.utils import (
|
||||
require_attn_tp_gather,
|
||||
require_gathered_buffer,
|
||||
|
||||
@@ -1,235 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import apply_custom_logit_processor
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
Req,
|
||||
ScheduleBatch,
|
||||
get_last_loc,
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
|
||||
from sglang.srt.speculative.spec_utils import (
|
||||
SIMULATE_ACC_LEN,
|
||||
TREE_SPEC_KERNEL_AVAILABLE,
|
||||
_generate_simulated_accept_index,
|
||||
align_evict_mask_to_page_size,
|
||||
assign_req_to_token_pool,
|
||||
create_accept_length_filter,
|
||||
create_extend_after_decode_spec_info,
|
||||
filter_finished_cache_loc_kernel,
|
||||
get_src_tgt_cache_loc,
|
||||
get_target_cache_loc,
|
||||
)
|
||||
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import (
|
||||
fast_topk,
|
||||
top_k_renorm_prob,
|
||||
top_p_renorm_prob,
|
||||
tree_speculative_sampling_target_only,
|
||||
verify_tree_greedy,
|
||||
)
|
||||
elif is_hip():
|
||||
from sgl_kernel import fast_topk, verify_tree_greedy
|
||||
|
||||
from sgl_kernel import verify_tree_greedy
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Simulate acceptance length for benchmarking purposes
|
||||
SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0
|
||||
SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
|
||||
|
||||
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
|
||||
|
||||
TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
|
||||
|
||||
|
||||
@dataclass
|
||||
class EagleDraftInput:
|
||||
# The inputs for decode
|
||||
# shape: (b, topk)
|
||||
topk_p: torch.Tensor = None
|
||||
topk_index: torch.Tensor = None
|
||||
# shape: (b, hidden_size)
|
||||
hidden_states: torch.Tensor = None
|
||||
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
|
||||
|
||||
# Inputs for extend
|
||||
# shape: (b,)
|
||||
verified_id: torch.Tensor = None
|
||||
accept_length: torch.Tensor = None
|
||||
accept_length_cpu: List[int] = None
|
||||
|
||||
# Inputs for the attention backends
|
||||
# shape: (b + 1,)
|
||||
kv_indptr: torch.Tensor = None
|
||||
kv_indices: torch.Tensor = None
|
||||
|
||||
# Shape info for padding
|
||||
num_tokens_per_batch: int = -1
|
||||
num_tokens_for_logprob_per_batch: int = -1
|
||||
|
||||
# Inputs for draft extend
|
||||
# shape: (b,)
|
||||
seq_lens_for_draft_extend: torch.Tensor = None
|
||||
req_pool_indices_for_draft_extend: torch.Tensor = None
|
||||
|
||||
def prepare_for_extend(self, batch: ScheduleBatch):
|
||||
|
||||
if batch.forward_mode.is_idle():
|
||||
return
|
||||
|
||||
# Prefill only generate 1 token.
|
||||
assert len(self.verified_id) == len(batch.seq_lens)
|
||||
|
||||
pt = 0
|
||||
for i, extend_len in enumerate(batch.extend_lens):
|
||||
input_ids = batch.input_ids[pt : pt + extend_len]
|
||||
batch.input_ids[pt : pt + extend_len] = torch.cat(
|
||||
(input_ids[1:], self.verified_id[i].reshape(1))
|
||||
)
|
||||
pt += extend_len
|
||||
|
||||
@classmethod
|
||||
def create_idle_input(
|
||||
cls,
|
||||
device: torch.device,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
topk: int,
|
||||
capture_hidden_mode: CaptureHiddenMode,
|
||||
):
|
||||
return cls(
|
||||
verified_id=torch.empty((0,), device=device, dtype=torch.int32),
|
||||
hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
|
||||
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
|
||||
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
|
||||
capture_hidden_mode=capture_hidden_mode,
|
||||
accept_length=torch.empty((0,), device=device, dtype=torch.int32),
|
||||
accept_length_cpu=[],
|
||||
)
|
||||
|
||||
def prepare_extend_after_decode(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
speculative_num_steps: int,
|
||||
):
|
||||
|
||||
if batch.forward_mode.is_idle():
|
||||
return
|
||||
|
||||
batch.input_ids = self.verified_id
|
||||
batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
|
||||
batch.extend_num_tokens = sum(batch.extend_lens)
|
||||
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
||||
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
|
||||
batch.return_logprob = False
|
||||
batch.return_hidden_states = False
|
||||
|
||||
self.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
self.accept_length.add_(1)
|
||||
self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
|
||||
self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
|
||||
|
||||
create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
|
||||
batch.input_ids,
|
||||
batch.seq_lens,
|
||||
self.accept_length,
|
||||
self.positions,
|
||||
self.verified_id,
|
||||
next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
|
||||
)
|
||||
|
||||
def generate_attn_arg_prefill(
|
||||
self,
|
||||
req_pool_indices: torch.Tensor,
|
||||
paged_kernel_lens: torch.Tensor,
|
||||
paged_kernel_lens_sum: int,
|
||||
req_to_token: torch.Tensor,
|
||||
):
|
||||
bs = self.accept_length.numel()
|
||||
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
|
||||
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
|
||||
if paged_kernel_lens_sum is None:
|
||||
paged_kernel_lens_sum = cum_kv_seq_len[-1]
|
||||
|
||||
kv_indices = torch.empty(
|
||||
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
req_to_token,
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
cum_kv_seq_len,
|
||||
None,
|
||||
kv_indices,
|
||||
req_to_token.size(1),
|
||||
)
|
||||
return kv_indices, cum_kv_seq_len, qo_indptr, None
|
||||
|
||||
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
|
||||
if has_been_filtered:
|
||||
# in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
|
||||
# therefore, we don't need to filter the batch again in scheduler
|
||||
if len(new_indices) != len(self.topk_p):
|
||||
logger.warning(
|
||||
f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
|
||||
)
|
||||
self.topk_p = self.topk_p[: len(new_indices)]
|
||||
self.topk_index = self.topk_index[: len(new_indices)]
|
||||
self.hidden_states = self.hidden_states[: len(new_indices)]
|
||||
self.verified_id = self.verified_id[: len(new_indices)]
|
||||
else:
|
||||
# in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
|
||||
self.topk_p = self.topk_p[new_indices]
|
||||
self.topk_index = self.topk_index[new_indices]
|
||||
self.hidden_states = self.hidden_states[new_indices]
|
||||
self.verified_id = self.verified_id[new_indices]
|
||||
|
||||
def merge_batch(self, spec_info: EagleDraftInput):
|
||||
if self.hidden_states is None:
|
||||
self.hidden_states = spec_info.hidden_states
|
||||
self.verified_id = spec_info.verified_id
|
||||
self.topk_p = spec_info.topk_p
|
||||
self.topk_index = spec_info.topk_index
|
||||
return
|
||||
if spec_info.hidden_states is None:
|
||||
return
|
||||
self.hidden_states = torch.cat(
|
||||
[self.hidden_states, spec_info.hidden_states], axis=0
|
||||
)
|
||||
self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
|
||||
self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
|
||||
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
|
||||
|
||||
|
||||
@dataclass
|
||||
class EagleVerifyOutput:
|
||||
# Draft input batch
|
||||
draft_input: EagleDraftInput
|
||||
# Logit outputs from target worker
|
||||
logits_output: LogitsProcessorOutput
|
||||
# Accepted token ids including the bonus token
|
||||
verified_id: torch.Tensor
|
||||
# Accepted token length per sequence in a batch in CPU.
|
||||
accept_length_per_req_cpu: List[int]
|
||||
# Accepted indices from logits_output.next_token_logits
|
||||
accepted_indices: torch.Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class EagleVerifyInput:
|
||||
class EagleVerifyInput(SpecInput):
|
||||
draft_token: torch.Tensor
|
||||
custom_mask: torch.Tensor
|
||||
positions: torch.Tensor
|
||||
@@ -245,6 +62,12 @@ class EagleVerifyInput:
|
||||
seq_lens_cpu: torch.Tensor
|
||||
grammar: BaseGrammarObject = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__init__(SpecInputType.EAGLE_VERIFY)
|
||||
|
||||
def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
|
||||
return self.draft_token_num, self.draft_token_num
|
||||
|
||||
@classmethod
|
||||
def create_idle_input(cls, topk: int, spec_steps: int, num_verify_tokens: int):
|
||||
return cls(
|
||||
@@ -724,574 +547,184 @@ class EagleVerifyInput:
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def create_extend_after_decode_spec_info(
|
||||
verified_id,
|
||||
seq_lens,
|
||||
accept_lens,
|
||||
positions,
|
||||
new_verified_id,
|
||||
bs_upper: tl.constexpr,
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
offsets = tl.arange(0, bs_upper)
|
||||
seq_length = tl.load(seq_lens + pid)
|
||||
accept_length = tl.load(accept_lens + pid)
|
||||
@dataclass
|
||||
class EagleDraftInput(SpecInput):
|
||||
# The inputs for decode
|
||||
# shape: (b, topk)
|
||||
topk_p: torch.Tensor = None
|
||||
topk_index: torch.Tensor = None
|
||||
# shape: (b, hidden_size)
|
||||
hidden_states: torch.Tensor = None
|
||||
capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.FULL
|
||||
|
||||
accept_len_cumsum = tl.sum(
|
||||
tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
|
||||
)
|
||||
positions_ptr = positions + accept_len_cumsum
|
||||
mask = offsets < accept_length
|
||||
tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)
|
||||
# Inputs for extend
|
||||
# shape: (b,)
|
||||
verified_id: torch.Tensor = None
|
||||
accept_length: torch.Tensor = None
|
||||
accept_length_cpu: List[int] = None
|
||||
|
||||
accept_len_cumsum += accept_length - 1
|
||||
verified_id_data = tl.load(verified_id + accept_len_cumsum)
|
||||
tl.store(new_verified_id + pid, verified_id_data)
|
||||
# Inputs for the attention backends
|
||||
# shape: (b + 1,)
|
||||
kv_indptr: torch.Tensor = None
|
||||
kv_indices: torch.Tensor = None
|
||||
|
||||
# Shape info for padding
|
||||
num_tokens_per_batch: int = -1
|
||||
num_tokens_for_logprob_per_batch: int = -1
|
||||
|
||||
@triton.jit
|
||||
def assign_req_to_token_pool(
|
||||
req_pool_indices,
|
||||
req_to_token,
|
||||
start_offset,
|
||||
end_offset,
|
||||
out_cache_loc,
|
||||
pool_len: tl.constexpr,
|
||||
bs_upper: tl.constexpr,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 32
|
||||
pid = tl.program_id(axis=0)
|
||||
kv_start = tl.load(start_offset + pid)
|
||||
kv_end = tl.load(end_offset + pid)
|
||||
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
|
||||
# Inputs for draft extend
|
||||
# shape: (b,)
|
||||
seq_lens_for_draft_extend: torch.Tensor = None
|
||||
req_pool_indices_for_draft_extend: torch.Tensor = None
|
||||
|
||||
length_offset = tl.arange(0, bs_upper)
|
||||
start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
|
||||
end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
|
||||
out_offset = tl.sum(end - start, axis=0)
|
||||
def __post_init__(self):
|
||||
super().__init__(SpecInputType.EAGLE_DRAFT)
|
||||
|
||||
out_cache_ptr = out_cache_loc + out_offset
|
||||
def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
|
||||
return self.num_tokens_per_batch, self.num_tokens_for_logprob_per_batch
|
||||
|
||||
save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
|
||||
load_offset = tl.arange(0, BLOCK_SIZE)
|
||||
def prepare_for_extend(self, batch: ScheduleBatch):
|
||||
|
||||
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
||||
for _ in range(num_loop):
|
||||
mask = save_offset < kv_end
|
||||
data = tl.load(out_cache_ptr + load_offset, mask=mask)
|
||||
tl.store(token_pool + save_offset, data, mask=mask)
|
||||
save_offset += BLOCK_SIZE
|
||||
load_offset += BLOCK_SIZE
|
||||
if batch.forward_mode.is_idle():
|
||||
return
|
||||
|
||||
# Prefill only generate 1 token.
|
||||
assert len(self.verified_id) == len(batch.seq_lens)
|
||||
|
||||
@triton.jit
|
||||
def assign_draft_cache_locs(
|
||||
req_pool_indices,
|
||||
req_to_token,
|
||||
seq_lens,
|
||||
extend_lens,
|
||||
num_new_pages_per_topk,
|
||||
out_cache_loc,
|
||||
pool_len: tl.constexpr,
|
||||
topk: tl.constexpr,
|
||||
speculative_num_steps: tl.constexpr,
|
||||
page_size: tl.constexpr,
|
||||
bs_upper: tl.constexpr,
|
||||
iter_upper: tl.constexpr,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 128
|
||||
pid = tl.program_id(axis=0)
|
||||
pt = 0
|
||||
for i, extend_len in enumerate(batch.extend_lens):
|
||||
input_ids = batch.input_ids[pt : pt + extend_len]
|
||||
batch.input_ids[pt : pt + extend_len] = torch.cat(
|
||||
(input_ids[1:], self.verified_id[i].reshape(1))
|
||||
)
|
||||
pt += extend_len
|
||||
|
||||
if page_size == 1 or topk == 1:
|
||||
copy_len = topk * speculative_num_steps
|
||||
out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
|
||||
else:
|
||||
bs_offset = tl.arange(0, bs_upper)
|
||||
copy_len = tl.load(extend_lens + pid)
|
||||
cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
|
||||
out_cache_ptr = out_cache_loc + cum_copy_len
|
||||
|
||||
# Part 1: Copy from out_cache_loc to req_to_token
|
||||
kv_start = tl.load(seq_lens + pid)
|
||||
token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
|
||||
num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
|
||||
for i in range(num_loop):
|
||||
copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
||||
mask = copy_offset < copy_len
|
||||
data = tl.load(out_cache_ptr + copy_offset, mask=mask)
|
||||
tl.store(token_pool + kv_start + copy_offset, data, mask=mask)
|
||||
|
||||
if page_size == 1 or topk == 1:
|
||||
return
|
||||
|
||||
# Part 2: Copy the indices for the last partial page
|
||||
prefix_len = tl.load(seq_lens + pid)
|
||||
last_page_len = prefix_len % page_size
|
||||
offsets = tl.arange(0, page_size)
|
||||
mask = offsets < last_page_len
|
||||
num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
|
||||
prefix_base = token_pool + prefix_len - last_page_len
|
||||
|
||||
for topk_id in range(topk):
|
||||
value = tl.load(prefix_base + offsets, mask=mask)
|
||||
tl.store(
|
||||
prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
|
||||
value,
|
||||
mask=mask,
|
||||
)
|
||||
|
||||
# Part 3: Remove the padding in out_cache_loc
|
||||
iter_offest = tl.arange(0, iter_upper)
|
||||
for topk_id in range(topk):
|
||||
indices = tl.load(
|
||||
prefix_base
|
||||
+ topk_id * num_new_pages_per_topk_ * page_size
|
||||
+ last_page_len
|
||||
+ iter_offest,
|
||||
mask=iter_offest < speculative_num_steps,
|
||||
)
|
||||
tl.store(
|
||||
out_cache_loc
|
||||
+ pid * topk * speculative_num_steps
|
||||
+ topk_id * speculative_num_steps
|
||||
+ iter_offest,
|
||||
indices,
|
||||
mask=iter_offest < speculative_num_steps,
|
||||
)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def generate_draft_decode_kv_indices(
|
||||
req_pool_indices,
|
||||
req_to_token,
|
||||
paged_kernel_lens,
|
||||
kv_indices,
|
||||
kv_indptr,
|
||||
positions,
|
||||
pool_len: tl.constexpr,
|
||||
kv_indices_stride: tl.constexpr,
|
||||
kv_indptr_stride: tl.constexpr,
|
||||
bs_upper: tl.constexpr,
|
||||
iter_upper: tl.constexpr,
|
||||
num_tokens_upper: tl.constexpr,
|
||||
page_size: tl.constexpr,
|
||||
):
|
||||
BLOCK_SIZE: tl.constexpr = 128
|
||||
iters = tl.program_id(axis=0)
|
||||
bid = tl.program_id(axis=1)
|
||||
topk_id = tl.program_id(axis=2)
|
||||
|
||||
num_steps = tl.num_programs(axis=0)
|
||||
num_seqs = tl.num_programs(axis=1)
|
||||
topk = tl.num_programs(axis=2)
|
||||
|
||||
kv_indices += kv_indices_stride * iters
|
||||
kv_indptr += kv_indptr_stride * iters
|
||||
iters += 1
|
||||
|
||||
load_offset = tl.arange(0, bs_upper)
|
||||
seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
|
||||
seq_len = tl.load(paged_kernel_lens + bid)
|
||||
cum_seq_len = tl.sum(seq_lens)
|
||||
|
||||
# Update kv_indices
|
||||
kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
|
||||
kv_ptr = kv_indices + kv_offset
|
||||
token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len
|
||||
|
||||
kv_offset = tl.arange(0, BLOCK_SIZE)
|
||||
num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
|
||||
for _ in range(num_loop):
|
||||
mask = kv_offset < seq_len
|
||||
data = tl.load(token_pool_ptr + kv_offset, mask=mask)
|
||||
tl.store(kv_ptr + kv_offset, data, mask=mask)
|
||||
kv_offset += BLOCK_SIZE
|
||||
|
||||
extend_offset = tl.arange(0, iter_upper)
|
||||
if page_size == 1 or topk == 1:
|
||||
extend_data = tl.load(
|
||||
token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
|
||||
mask=extend_offset < iters,
|
||||
)
|
||||
else:
|
||||
prefix_len = seq_len
|
||||
last_page_len = prefix_len % page_size
|
||||
num_new_pages_per_topk = (
|
||||
last_page_len + num_steps + page_size - 1
|
||||
) // page_size
|
||||
prefix_base = seq_len // page_size * page_size
|
||||
start = (
|
||||
prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
|
||||
)
|
||||
extend_data = tl.load(
|
||||
token_pool_ptr + start + extend_offset,
|
||||
mask=extend_offset < iters,
|
||||
)
|
||||
|
||||
tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)
|
||||
|
||||
# Update kv_indptr
|
||||
bs_offset = tl.arange(0, num_tokens_upper)
|
||||
|
||||
zid = bid * topk + topk_id
|
||||
if zid == 0:
|
||||
zid = num_seqs * topk
|
||||
positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
|
||||
base = tl.sum(positions)
|
||||
tl.store(kv_indptr + zid, base + zid * iters)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def align_evict_mask_to_page_size(
|
||||
seq_lens,
|
||||
evict_mask,
|
||||
page_size: tl.constexpr,
|
||||
num_draft_tokens: tl.constexpr,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
):
|
||||
t_range = tl.arange(0, BLOCK_SIZE)
|
||||
|
||||
bid = tl.program_id(axis=0)
|
||||
seq_len = tl.load(seq_lens + bid)
|
||||
io_mask = t_range < num_draft_tokens
|
||||
mask_row = tl.load(
|
||||
evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
|
||||
)
|
||||
|
||||
num_trues = tl.sum(mask_row)
|
||||
num_false = num_draft_tokens - num_trues
|
||||
|
||||
start = (seq_len + num_false - 1) // page_size * page_size - seq_len
|
||||
for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
|
||||
tl.store(evict_mask + bid * num_draft_tokens + i, False)
|
||||
|
||||
|
||||
@triton.jit
def get_target_cache_loc(
    tgt_cache_loc,
    to_free_slots,
    accept_length,
    to_free_num_slots,
    out_cache_loc,
    num_verify_tokens: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Split each row of ``out_cache_loc`` into kept and freed cache slots.

    One program per batch row (``bid``). The first ``accept_length[bid] + 1``
    slots (accepted tokens plus the bonus token) are packed into
    ``tgt_cache_loc``; the trailing ``to_free_num_slots[bid]`` slots are packed
    into ``to_free_slots``. Output start offsets are prefix sums over the
    preceding rows, so both outputs are densely packed across the batch.
    """
    bid = tl.program_id(axis=0)
    offset = tl.arange(0, num_verify_tokens_upper)
    bs_offset = tl.arange(0, bs_upper)

    # write the first part to tgt_cache_loc
    # Prefix-sum of accepted lengths of earlier rows (+bid for their bonus tokens)
    # gives this row's packed start position.
    accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
    copy_len = tl.load(accept_length + bid) + 1
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
    )
    tl.store(
        tgt_cache_loc + tgt_cache_loc_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )

    # write the second part to to_free_num_pages
    # The freed slots are taken from the tail of this row of out_cache_loc.
    to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
    to_free_slots_start = tl.sum(to_free_num_slots_all)

    copy_len = to_free_num_slots_cur
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
        mask=offset < copy_len,
    )
    tl.store(
        to_free_slots + to_free_slots_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def get_src_tgt_cache_loc(
|
||||
seq_lens: torch.Tensor,
|
||||
out_cache_loc: torch.Tensor,
|
||||
accept_index: torch.Tensor,
|
||||
accept_length: torch.Tensor,
|
||||
draft_token_num: int,
|
||||
page_size: int,
|
||||
):
|
||||
src_cache_loc = out_cache_loc[accept_index]
|
||||
tgt_cache_loc = torch.empty_like(src_cache_loc)
|
||||
extended_len = seq_lens + draft_token_num
|
||||
keep_len = torch.minimum(
|
||||
(seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
|
||||
extended_len,
|
||||
)
|
||||
to_free_num_slots = extended_len - keep_len
|
||||
return src_cache_loc, tgt_cache_loc, to_free_num_slots
|
||||
|
||||
|
||||
@triton.jit
def filter_finished_cache_loc_kernel(
    out_cache_loc,
    tgt_cache_loc,
    accept_length,
    accept_length_filter,
    bs_upper: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
):
    """Compact ``tgt_cache_loc`` into ``out_cache_loc``, dropping finished rows.

    One program per batch row. ``accept_length_filter`` holds the kept length
    per row (zero for rows being filtered out), so the prefix sums over
    ``accept_length`` (+1 bonus token per row) and ``accept_length_filter``
    give the packed source and destination offsets respectively.
    """
    bid = tl.program_id(0)
    bs_offset = tl.arange(0, bs_upper)

    # Source offset: packed layout produced with one bonus token per row.
    accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    old_start = tl.sum(accept_length_all) + bid

    # Destination offset: packed layout after filtering.
    accept_length_filter_all = tl.load(
        accept_length_filter + bs_offset, mask=bs_offset < bid
    )
    new_start = tl.sum(accept_length_filter_all)

    copy_len = tl.load(accept_length_filter + bid)
    copy_offset = tl.arange(0, num_verify_tokens_upper)
    value = tl.load(
        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
    )
    tl.store(
        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
    )
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def create_accept_length_filter(
|
||||
accept_length: torch.Tensor,
|
||||
unfinished_index_device: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
):
|
||||
accept_length_filter = torch.zeros_like(accept_length)
|
||||
accept_length_filter[unfinished_index_device] = (
|
||||
accept_length[unfinished_index_device] + 1
|
||||
)
|
||||
seq_lens.add_(accept_length + 1)
|
||||
return accept_length_filter
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def select_top_k_tokens(
|
||||
i: int,
|
||||
topk_p: torch.Tensor,
|
||||
topk_index: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
scores: torch.Tensor,
|
||||
topk: int,
|
||||
):
|
||||
if i == 0:
|
||||
# The first step after extend
|
||||
input_ids = topk_index.flatten()
|
||||
hidden_states = hidden_states.repeat_interleave(topk, dim=0)
|
||||
scores = topk_p # shape: (b, topk)
|
||||
|
||||
tree_info = (
|
||||
topk_p.unsqueeze(1), # shape: (b, 1, topk)
|
||||
topk_index, # shape: (b, topk)
|
||||
torch.arange(-1, topk, dtype=torch.long, device="cuda")
|
||||
.unsqueeze(0)
|
||||
.repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
|
||||
)
|
||||
else:
|
||||
# The later decode steps
|
||||
expand_scores = torch.mul(
|
||||
scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
|
||||
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
|
||||
topk_cs_p, topk_cs_index = fast_topk(
|
||||
expand_scores.flatten(start_dim=1), topk, dim=-1
|
||||
) # (b, topk)
|
||||
scores = topk_cs_p # shape: (b, topk)
|
||||
|
||||
topk_index = topk_index.reshape(-1, topk**2)
|
||||
input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
|
||||
|
||||
if hidden_states.shape[0] > 0:
|
||||
selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
|
||||
0, hidden_states.shape[0], step=topk, device="cuda"
|
||||
).repeat_interleave(topk)
|
||||
hidden_states = hidden_states[selected_input_index, :]
|
||||
|
||||
tree_info = (
|
||||
expand_scores, # shape: (b, topk, topk)
|
||||
topk_index, # shape: (b, topk * topk)
|
||||
topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk)
|
||||
)
|
||||
|
||||
return input_ids, hidden_states, scores, tree_info
|
||||
|
||||
|
||||
def _generate_simulated_accept_index(
    accept_index,
    predict,
    accept_length,
    bs,
    spec_steps,
    simulate_acc_len: float = SIMULATE_ACC_LEN,
    simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
    """Overwrite verification results with a simulated acceptance length.

    Used for benchmarking: instead of the real accept/reject outcome, every
    request accepts a fixed (sampled) number of draft tokens. ``accept_index``,
    ``accept_length`` and ``predict`` are rewritten accordingly and the new
    accept-index tensor is returned.

    Fixes:
      - The invalid-method error now reports the actual ``simulate_acc_method``
        argument instead of the module-level default.
      - Output device is taken from ``accept_index`` instead of hard-coded
        ``"cuda"``.
    """
    assert simulate_acc_len > 0.0

    if simulate_acc_method == "multinomial":
        simulated_values = torch.normal(
            mean=simulate_acc_len,
            std=1.0,
            size=(1,),
            device="cpu",
        )
        # clamp simulated values to be between 1 and spec_steps + 1
        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
        simulate_acc_len = int(simulated_values.round().item())
    elif simulate_acc_method == "match-expected":
        # multinomial sampling does not match the expected length
        # we keep it for the sake of compatibility of existing tests
        # but it's better to use "match-expected" for the cases that need to
        # match the expected length, One caveat is that this will only sample
        # either round down or round up of the expected length
        simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
        lower = int(simulate_acc_len // 1)
        upper = lower + 1 if lower < spec_steps + 1 else lower
        if lower == upper:
            simulate_acc_len = lower
        else:
            # Sample lower/upper with weights chosen so the expectation matches.
            weight_upper = simulate_acc_len - lower
            weight_lower = 1.0 - weight_upper
            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
            sampled_index = torch.multinomial(probs, num_samples=1)
            simulate_acc_len = lower if sampled_index == 0 else upper
    else:
        # Report the argument actually passed, not the module-level default.
        raise ValueError(f"Invalid simulate_acc_method: {simulate_acc_method}")

    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
    sim_accept_index = torch.full(
        (bs, spec_steps + 1), -1, dtype=torch.int32, device=accept_index.device
    )
    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
        simulate_acc_len, device=accept_index.device
    )
    accept_length.fill_(simulate_acc_len - 1)
    predict.fill_(100)  # some legit token id
    return sim_accept_index
|
||||
|
||||
|
||||
def traverse_tree(
|
||||
retrieve_next_token: torch.Tensor,
|
||||
retrieve_next_sibling: torch.Tensor,
|
||||
draft_tokens: torch.Tensor,
|
||||
grammar: BaseGrammarObject,
|
||||
allocate_token_bitmask: torch.Tensor,
|
||||
):
|
||||
"""
|
||||
Traverse the tree constructed by the draft model to generate the logits mask.
|
||||
"""
|
||||
assert (
|
||||
retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
|
||||
)
|
||||
|
||||
allocate_token_bitmask.fill_(0)
|
||||
|
||||
def dfs(
|
||||
curr: int,
|
||||
retrieve_next_token: torch.Tensor,
|
||||
retrieve_next_sibling: torch.Tensor,
|
||||
parent_pos: int,
|
||||
@classmethod
|
||||
def create_idle_input(
|
||||
cls,
|
||||
device: torch.device,
|
||||
hidden_size: int,
|
||||
dtype: torch.dtype,
|
||||
topk: int,
|
||||
capture_hidden_mode: CaptureHiddenMode,
|
||||
):
|
||||
if curr == 0:
|
||||
# the first token generated by the target model, and thus it is always
|
||||
# accepted from the previous iteration
|
||||
accepted = True
|
||||
else:
|
||||
parent_bitmask = allocate_token_bitmask[parent_pos]
|
||||
curr_token_id = draft_tokens[curr]
|
||||
# 32 boolean bitmask values are packed into 32-bit integers
|
||||
accepted = (
|
||||
parent_bitmask[curr_token_id // 32] & (1 << (curr_token_id % 32))
|
||||
) != 0
|
||||
return cls(
|
||||
verified_id=torch.empty((0,), device=device, dtype=torch.int32),
|
||||
hidden_states=torch.empty((0, hidden_size), device=device, dtype=dtype),
|
||||
topk_p=torch.empty((0, topk), device=device, dtype=torch.float32),
|
||||
topk_index=torch.empty((0, topk), device=device, dtype=torch.int64),
|
||||
capture_hidden_mode=capture_hidden_mode,
|
||||
accept_length=torch.empty((0,), device=device, dtype=torch.int32),
|
||||
accept_length_cpu=[],
|
||||
)
|
||||
|
||||
if accepted:
|
||||
if curr != 0:
|
||||
# Accept the current token
|
||||
grammar.accept_token(draft_tokens[curr])
|
||||
if not grammar.is_terminated():
|
||||
# Generate the bitmask for the current token
|
||||
grammar.fill_vocab_mask(allocate_token_bitmask, curr)
|
||||
if retrieve_next_token[curr] != -1:
|
||||
# Visit the child node
|
||||
dfs(
|
||||
retrieve_next_token[curr],
|
||||
retrieve_next_token,
|
||||
retrieve_next_sibling,
|
||||
curr,
|
||||
)
|
||||
def prepare_extend_after_decode(
|
||||
self,
|
||||
batch: ScheduleBatch,
|
||||
speculative_num_steps: int,
|
||||
):
|
||||
|
||||
if curr != 0:
|
||||
# Rollback the current token
|
||||
grammar.rollback(1)
|
||||
if batch.forward_mode.is_idle():
|
||||
return
|
||||
|
||||
if retrieve_next_sibling[curr] != -1:
|
||||
# Visit the sibling node
|
||||
dfs(
|
||||
retrieve_next_sibling[curr],
|
||||
retrieve_next_token,
|
||||
retrieve_next_sibling,
|
||||
parent_pos,
|
||||
)
|
||||
batch.input_ids = self.verified_id
|
||||
batch.extend_lens = [x + 1 for x in batch.spec_info.accept_length_cpu]
|
||||
batch.extend_num_tokens = sum(batch.extend_lens)
|
||||
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
||||
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
|
||||
batch.return_logprob = False
|
||||
batch.return_hidden_states = False
|
||||
|
||||
dfs(0, retrieve_next_token, retrieve_next_sibling, -1)
|
||||
self.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||
self.accept_length.add_(1)
|
||||
self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
|
||||
self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
|
||||
|
||||
create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
|
||||
batch.input_ids,
|
||||
batch.seq_lens,
|
||||
self.accept_length,
|
||||
self.positions,
|
||||
self.verified_id,
|
||||
next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
|
||||
)
|
||||
|
||||
def generate_token_bitmask(
|
||||
reqs: List[Req],
|
||||
verify_input: EagleVerifyInput,
|
||||
retrieve_next_token_cpu: torch.Tensor,
|
||||
retrieve_next_sibling_cpu: torch.Tensor,
|
||||
draft_tokens_cpu: torch.Tensor,
|
||||
vocab_size: int,
|
||||
):
|
||||
"""
|
||||
Generate the logit mask for structured output.
|
||||
Draft model's token can be either valid or invalid with respect to the grammar.
|
||||
We need to perform DFS to
|
||||
1. figure out which tokens are accepted by the grammar.
|
||||
2. if so, what is the corresponding logit mask.
|
||||
"""
|
||||
def generate_attn_arg_prefill(
|
||||
self,
|
||||
req_pool_indices: torch.Tensor,
|
||||
paged_kernel_lens: torch.Tensor,
|
||||
paged_kernel_lens_sum: int,
|
||||
req_to_token: torch.Tensor,
|
||||
):
|
||||
bs = self.accept_length.numel()
|
||||
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
|
||||
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||
|
||||
num_draft_tokens = draft_tokens_cpu.shape[-1]
|
||||
if paged_kernel_lens_sum is None:
|
||||
paged_kernel_lens_sum = cum_kv_seq_len[-1]
|
||||
|
||||
allocate_token_bitmask = None
|
||||
assert len(reqs) == retrieve_next_token_cpu.shape[0]
|
||||
grammar = None
|
||||
for i, req in enumerate(reqs):
|
||||
if req.grammar is not None:
|
||||
if allocate_token_bitmask is None:
|
||||
allocate_token_bitmask = req.grammar.allocate_vocab_mask(
|
||||
vocab_size=vocab_size,
|
||||
batch_size=draft_tokens_cpu.numel(),
|
||||
device="cpu",
|
||||
)
|
||||
grammar = req.grammar
|
||||
s = time.perf_counter()
|
||||
traverse_tree(
|
||||
retrieve_next_token_cpu[i],
|
||||
retrieve_next_sibling_cpu[i],
|
||||
draft_tokens_cpu[i],
|
||||
req.grammar,
|
||||
allocate_token_bitmask[
|
||||
i * num_draft_tokens : (i + 1) * num_draft_tokens
|
||||
],
|
||||
)
|
||||
tree_traverse_time = time.perf_counter() - s
|
||||
if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
|
||||
kv_indices = torch.empty(
|
||||
paged_kernel_lens_sum, dtype=torch.int32, device="cuda"
|
||||
)
|
||||
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
req_to_token,
|
||||
req_pool_indices,
|
||||
paged_kernel_lens,
|
||||
cum_kv_seq_len,
|
||||
None,
|
||||
kv_indices,
|
||||
req_to_token.size(1),
|
||||
)
|
||||
return kv_indices, cum_kv_seq_len, qo_indptr, None
|
||||
|
||||
def filter_batch(self, new_indices: torch.Tensor, has_been_filtered: bool = True):
|
||||
if has_been_filtered:
|
||||
# in eagle_utils.py:verify, we have already filtered the batch by `unfinished_index`
|
||||
# therefore, we don't need to filter the batch again in scheduler
|
||||
if len(new_indices) != len(self.topk_p):
|
||||
logger.warning(
|
||||
f"Bit mask generation took {tree_traverse_time} seconds with "
|
||||
f"grammar: {req.grammar}"
|
||||
f"length of new_indices: {len(new_indices)} != length of topk_p: {len(self.topk_p)}, this should not happen"
|
||||
)
|
||||
self.topk_p = self.topk_p[: len(new_indices)]
|
||||
self.topk_index = self.topk_index[: len(new_indices)]
|
||||
self.hidden_states = self.hidden_states[: len(new_indices)]
|
||||
self.verified_id = self.verified_id[: len(new_indices)]
|
||||
else:
|
||||
# in some cases(e.g draft_extend), we have not filtered the batch by `unfinished_index`
|
||||
self.topk_p = self.topk_p[new_indices]
|
||||
self.topk_index = self.topk_index[new_indices]
|
||||
self.hidden_states = self.hidden_states[new_indices]
|
||||
self.verified_id = self.verified_id[new_indices]
|
||||
|
||||
verify_input.grammar = grammar
|
||||
return allocate_token_bitmask
|
||||
def merge_batch(self, spec_info: "EagleDraftInput"):
    """Merge another draft input's per-request state into this one.

    If this instance has no hidden states yet, adopt the other's tensors
    wholesale; if the other has none, do nothing. Otherwise concatenate the
    per-request tensors (hidden_states, verified_id, topk_p, topk_index)
    along the batch dimension.
    """
    if self.hidden_states is None:
        # This side is empty: take over the other batch's state.
        self.hidden_states = spec_info.hidden_states
        self.verified_id = spec_info.verified_id
        self.topk_p = spec_info.topk_p
        self.topk_index = spec_info.topk_index
        return
    if spec_info.hidden_states is None:
        # Nothing to merge in.
        return
    self.hidden_states = torch.cat(
        [self.hidden_states, spec_info.hidden_states], axis=0
    )
    self.verified_id = torch.cat([self.verified_id, spec_info.verified_id], axis=0)
    self.topk_p = torch.cat([self.topk_p, spec_info.topk_p])
    self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
|
||||
|
||||
|
||||
@dataclass
|
||||
class EagleVerifyOutput:
|
||||
# Draft input batch
|
||||
draft_input: EagleDraftInput
|
||||
# Logit outputs from target worker
|
||||
logits_output: LogitsProcessorOutput
|
||||
# Accepted token ids including the bonus token
|
||||
verified_id: torch.Tensor
|
||||
# Accepted token length per sequence in a batch in CPU.
|
||||
accept_length_per_req_cpu: List[int]
|
||||
# Accepted indices from logits_output.next_token_logits
|
||||
accepted_indices: torch.Tensor
|
||||
@@ -34,16 +34,18 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
||||
from sglang.srt.speculative.eagle_draft_extend_cuda_graph_runner import (
|
||||
EAGLEDraftExtendCudaGraphRunner,
|
||||
)
|
||||
from sglang.srt.speculative.eagle_utils import (
|
||||
from sglang.srt.speculative.eagle_info import (
|
||||
EagleDraftInput,
|
||||
EagleVerifyInput,
|
||||
EagleVerifyOutput,
|
||||
)
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.speculative.spec_utils import (
|
||||
assign_draft_cache_locs,
|
||||
fast_topk,
|
||||
generate_token_bitmask,
|
||||
select_top_k_tokens,
|
||||
)
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import (
|
||||
empty_context,
|
||||
get_available_gpu_memory,
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import logging
|
||||
from typing import Optional
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
import triton
|
||||
@@ -13,6 +13,7 @@ from dataclasses import dataclass
|
||||
|
||||
import torch.nn.functional as F
|
||||
|
||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.layers.sampler import apply_custom_logit_processor
|
||||
from sglang.srt.managers.schedule_batch import (
|
||||
@@ -21,10 +22,10 @@ from sglang.srt.managers.schedule_batch import (
|
||||
global_server_args_dict,
|
||||
)
|
||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||
from sglang.srt.speculative.eagle_utils import (
|
||||
from sglang.srt.speculative.spec_info import SpecInput, SpecInputType
|
||||
from sglang.srt.speculative.spec_utils import (
|
||||
TREE_SPEC_KERNEL_AVAILABLE,
|
||||
assign_req_to_token_pool,
|
||||
create_flashinfer_kv_indices_triton,
|
||||
get_src_tgt_cache_loc,
|
||||
get_target_cache_loc,
|
||||
)
|
||||
@@ -42,7 +43,7 @@ elif is_hip():
|
||||
|
||||
|
||||
@dataclass
|
||||
class NgramVerifyInput:
|
||||
class NgramVerifyInput(SpecInput):
|
||||
def __init__(
|
||||
self,
|
||||
draft_token: torch.Tensor,
|
||||
@@ -53,6 +54,7 @@ class NgramVerifyInput:
|
||||
retrive_next_sibling: torch.Tensor,
|
||||
draft_token_num: int,
|
||||
):
|
||||
super().__init__(SpecInputType.NGRAM_VERIFY)
|
||||
self.draft_token = draft_token
|
||||
self.custom_mask = tree_mask
|
||||
self.positions = positions
|
||||
@@ -62,6 +64,9 @@ class NgramVerifyInput:
|
||||
self.draft_token_num = draft_token_num
|
||||
self.device = self.custom_mask.device
|
||||
|
||||
def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
|
||||
return self.draft_token_num, self.draft_token_num
|
||||
|
||||
def prepare_for_verify(self, batch: ScheduleBatch, page_size: int):
|
||||
if batch.forward_mode.is_idle():
|
||||
return
|
||||
|
||||
@@ -1,8 +1,5 @@
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
from typing import List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -15,7 +12,6 @@ from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.speculative.cpp_ngram.ngram_cache import NgramCache
|
||||
from sglang.srt.speculative.ngram_utils import NgramVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||
from sglang.srt.utils import broadcast_pyobj
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import IntEnum, auto
|
||||
from typing import List, Tuple
|
||||
|
||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
||||
|
||||
|
||||
class SpeculativeAlgorithm(IntEnum):
|
||||
@@ -35,3 +39,41 @@ class SpeculativeAlgorithm(IntEnum):
|
||||
if name is not None:
|
||||
name = name.upper()
|
||||
return name_map[name]
|
||||
|
||||
|
||||
class SpecInputType(IntEnum):
|
||||
# NOTE: introduce this to distinguish the SpecInput types of multiple algorithms when asserting in attention backends.
|
||||
# If all algorithms can share the same datastrucutre of draft_input and verify_input, consider simplify it
|
||||
EAGLE_DRAFT = auto()
|
||||
EAGLE_VERIFY = auto()
|
||||
NGRAM_VERIFY = auto()
|
||||
|
||||
|
||||
class SpecInput(ABC):
|
||||
def __init__(self, spec_input_type: SpecInputType):
|
||||
self.spec_input_type = spec_input_type
|
||||
|
||||
def is_draft_input(self) -> bool:
|
||||
# FIXME: remove this function which is only used for assertion
|
||||
# or use another variable name like `draft_input` to substitute `spec_info`
|
||||
return self.spec_input_type == SpecInputType.EAGLE_DRAFT
|
||||
|
||||
def is_verify_input(self) -> bool:
|
||||
return self.spec_input_type in {
|
||||
SpecInputType.EAGLE_VERIFY,
|
||||
SpecInputType.NGRAM_VERIFY,
|
||||
}
|
||||
|
||||
@abstractmethod
|
||||
def get_spec_adjust_token_coefficient(self) -> Tuple[int, int]:
|
||||
pass
|
||||
|
||||
def get_spec_adjusted_global_num_tokens(
|
||||
self, forward_batch: ModelWorkerBatch
|
||||
) -> Tuple[List[int], List[int]]:
|
||||
c1, c2 = self.get_spec_adjust_token_coefficient()
|
||||
global_num_tokens = [x * c1 for x in forward_batch.global_num_tokens]
|
||||
global_num_tokens_for_logprob = [
|
||||
x * c2 for x in forward_batch.global_num_tokens_for_logprob
|
||||
]
|
||||
return global_num_tokens, global_num_tokens_for_logprob
|
||||
|
||||
607
python/sglang/srt/speculative/spec_utils.py
Normal file
607
python/sglang/srt/speculative/spec_utils.py
Normal file
@@ -0,0 +1,607 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
import torch
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||
from sglang.srt.environ import envs
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
from sglang.srt.utils import is_cuda, is_hip
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import fast_topk
|
||||
elif is_hip():
|
||||
from sgl_kernel import fast_topk
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.speculative.eagle_info import EagleVerifyInput
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Simulate acceptance length for benchmarking purposes
|
||||
SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get() # turn off if < 0
|
||||
SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()
|
||||
|
||||
TREE_TRAVERSE_TIME_THRESHOLD = 1 # TODO: set this properly
|
||||
|
||||
TREE_SPEC_KERNEL_AVAILABLE = "tree_speculative_sampling_target_only" in globals()
|
||||
|
||||
|
||||
@triton.jit
def create_extend_after_decode_spec_info(
    verified_id,
    seq_lens,
    accept_lens,
    positions,
    new_verified_id,
    bs_upper: tl.constexpr,
):
    """Fill per-token positions and per-request last verified ids.

    One program per request (``pid``). For request ``pid`` the kernel writes
    the position of each of its accepted tokens into the packed ``positions``
    buffer, and stores the last accepted token's id into
    ``new_verified_id[pid]``.
    """
    pid = tl.program_id(axis=0)
    offsets = tl.arange(0, bs_upper)
    seq_length = tl.load(seq_lens + pid)
    accept_length = tl.load(accept_lens + pid)

    # Packed start offset for this request = sum of earlier accept lengths.
    accept_len_cumsum = tl.sum(
        tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
    )
    positions_ptr = positions + accept_len_cumsum
    mask = offsets < accept_length
    # Positions count backwards from the end of the sequence.
    tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)

    # The last accepted token of this request becomes its new verified id.
    accept_len_cumsum += accept_length - 1
    verified_id_data = tl.load(verified_id + accept_len_cumsum)
    tl.store(new_verified_id + pid, verified_id_data)
|
||||
|
||||
|
||||
@triton.jit
def assign_req_to_token_pool(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Copy packed cache locations into each request's req_to_token row.

    One program per request (``pid``). The packed ``out_cache_loc`` buffer is
    laid out request after request; this request's slice starts at the sum of
    the earlier requests' (end - start) spans and is written into
    ``req_to_token[req_pool_indices[pid], start:end]``.
    """
    BLOCK_SIZE: tl.constexpr = 32
    pid = tl.program_id(axis=0)
    kv_start = tl.load(start_offset + pid)
    kv_end = tl.load(end_offset + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len

    # Prefix sum of earlier requests' spans locates our slice of out_cache_loc.
    length_offset = tl.arange(0, bs_upper)
    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
    out_offset = tl.sum(end - start, axis=0)

    out_cache_ptr = out_cache_loc + out_offset

    save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
    load_offset = tl.arange(0, BLOCK_SIZE)

    # Block-strided copy of (kv_end - kv_start) entries.
    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = save_offset < kv_end
        data = tl.load(out_cache_ptr + load_offset, mask=mask)
        tl.store(token_pool + save_offset, data, mask=mask)
        save_offset += BLOCK_SIZE
        load_offset += BLOCK_SIZE
|
||||
|
||||
|
||||
@triton.jit
def assign_draft_cache_locs(
    req_pool_indices,
    req_to_token,
    seq_lens,
    extend_lens,
    num_new_pages_per_topk,
    out_cache_loc,
    pool_len: tl.constexpr,
    topk: tl.constexpr,
    speculative_num_steps: tl.constexpr,
    page_size: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
):
    """Assign KV-cache locations for the draft tokens of each request.

    One program per request (``pid``). With page_size == 1 or topk == 1 the
    layout is dense (topk * speculative_num_steps slots per request); otherwise
    each top-k branch gets its own page-aligned region, the last partial page
    of the prefix is duplicated into every branch, and the padding is stripped
    back out of ``out_cache_loc``.
    """
    BLOCK_SIZE: tl.constexpr = 128
    pid = tl.program_id(axis=0)

    if page_size == 1 or topk == 1:
        # Dense layout: fixed stride per request.
        copy_len = topk * speculative_num_steps
        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
    else:
        # Paged layout: per-request variable length; locate our packed slice
        # via a prefix sum of earlier extend lengths.
        bs_offset = tl.arange(0, bs_upper)
        copy_len = tl.load(extend_lens + pid)
        cum_copy_len = tl.sum(tl.load(extend_lens + bs_offset, mask=bs_offset < pid))
        out_cache_ptr = out_cache_loc + cum_copy_len

    # Part 1: Copy from out_cache_loc to req_to_token
    kv_start = tl.load(seq_lens + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
    for i in range(num_loop):
        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = copy_offset < copy_len
        data = tl.load(out_cache_ptr + copy_offset, mask=mask)
        tl.store(token_pool + kv_start + copy_offset, data, mask=mask)

    if page_size == 1 or topk == 1:
        return

    # Part 2: Copy the indices for the last partial page
    # Each top-k branch needs its own copy of the prefix's trailing partial page.
    prefix_len = tl.load(seq_lens + pid)
    last_page_len = prefix_len % page_size
    offsets = tl.arange(0, page_size)
    mask = offsets < last_page_len
    num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
    prefix_base = token_pool + prefix_len - last_page_len

    for topk_id in range(topk):
        value = tl.load(prefix_base + offsets, mask=mask)
        tl.store(
            prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
            value,
            mask=mask,
        )

    # Part 3: Remove the padding in out_cache_loc
    # Repack each branch's speculative slots densely (stride per branch is
    # speculative_num_steps) so downstream consumers see no page padding.
    iter_offest = tl.arange(0, iter_upper)
    for topk_id in range(topk):
        indices = tl.load(
            prefix_base
            + topk_id * num_new_pages_per_topk_ * page_size
            + last_page_len
            + iter_offest,
            mask=iter_offest < speculative_num_steps,
        )
        tl.store(
            out_cache_loc
            + pid * topk * speculative_num_steps
            + topk_id * speculative_num_steps
            + iter_offest,
            indices,
            mask=iter_offest < speculative_num_steps,
        )
|
||||
|
||||
|
||||
@triton.jit
def generate_draft_decode_kv_indices(
    req_pool_indices,
    req_to_token,
    paged_kernel_lens,
    kv_indices,
    kv_indptr,
    positions,
    pool_len: tl.constexpr,
    kv_indices_stride: tl.constexpr,
    kv_indptr_stride: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
    num_tokens_upper: tl.constexpr,
    page_size: tl.constexpr,
):
    """Build per-step KV indices and indptr for draft decode attention.

    3D grid: axis 0 = draft step, axis 1 = request, axis 2 = top-k branch.
    For each (step, request, branch) the kernel writes the request's prefix
    token locations followed by the branch's speculative token locations into
    that step's slice of ``kv_indices``, and fills the matching ``kv_indptr``
    entry from the cumulative ``positions``.
    """
    BLOCK_SIZE: tl.constexpr = 128
    iters = tl.program_id(axis=0)
    bid = tl.program_id(axis=1)
    topk_id = tl.program_id(axis=2)

    num_steps = tl.num_programs(axis=0)
    num_seqs = tl.num_programs(axis=1)
    topk = tl.num_programs(axis=2)

    # Each draft step has its own kv_indices/kv_indptr slice.
    kv_indices += kv_indices_stride * iters
    kv_indptr += kv_indptr_stride * iters
    iters += 1

    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
    seq_len = tl.load(paged_kernel_lens + bid)
    cum_seq_len = tl.sum(seq_lens)

    # Update kv_indices
    # Start of this (request, branch) segment within the step's slice.
    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
    kv_ptr = kv_indices + kv_offset
    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len

    # Block-strided copy of the prefix token locations.
    kv_offset = tl.arange(0, BLOCK_SIZE)
    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = kv_offset < seq_len
        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
        tl.store(kv_ptr + kv_offset, data, mask=mask)
        kv_offset += BLOCK_SIZE

    extend_offset = tl.arange(0, iter_upper)
    if page_size == 1 or topk == 1:
        # Dense layout: speculative slots follow the prefix at a fixed stride.
        extend_data = tl.load(
            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
            mask=extend_offset < iters,
        )
    else:
        # Paged layout: mirror the placement done by assign_draft_cache_locs.
        prefix_len = seq_len
        last_page_len = prefix_len % page_size
        num_new_pages_per_topk = (
            last_page_len + num_steps + page_size - 1
        ) // page_size
        prefix_base = seq_len // page_size * page_size
        start = (
            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
        )
        extend_data = tl.load(
            token_pool_ptr + start + extend_offset,
            mask=extend_offset < iters,
        )

    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)

    # Update kv_indptr
    bs_offset = tl.arange(0, num_tokens_upper)

    zid = bid * topk + topk_id
    if zid == 0:
        # Entry 0 is implicitly 0; reuse this program to write the final entry.
        zid = num_seqs * topk
    positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
    base = tl.sum(positions)
    tl.store(kv_indptr + zid, base + zid * iters)
|
||||
|
||||
|
||||
@triton.jit
def align_evict_mask_to_page_size(
    seq_lens,
    evict_mask,
    page_size: tl.constexpr,
    num_draft_tokens: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Un-evict slots inside the last partially-kept page (one program per row).

    After counting kept slots (mask False = keep), entries of the page that
    straddles the end of the kept region are forced back to False so a page is
    never split between kept and evicted slots.
    """
    t_range = tl.arange(0, BLOCK_SIZE)

    bid = tl.program_id(axis=0)
    seq_len = tl.load(seq_lens + bid)
    io_mask = t_range < num_draft_tokens
    mask_row = tl.load(
        evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
    )

    # True marks eviction; num_false is the kept count.
    num_trues = tl.sum(mask_row)
    num_false = num_draft_tokens - num_trues

    # Index (relative to draft tokens) of the page boundary past the kept slots.
    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
        tl.store(evict_mask + bid * num_draft_tokens + i, False)
|
||||
|
||||
|
||||
@triton.jit
def get_target_cache_loc(
    tgt_cache_loc,
    to_free_slots,
    accept_length,
    to_free_num_slots,
    out_cache_loc,
    num_verify_tokens: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Split each row of ``out_cache_loc`` into kept and freed cache slots.

    One program per batch row: the first ``accept_length[bid] + 1`` slots are
    packed into ``tgt_cache_loc`` and the trailing ``to_free_num_slots[bid]``
    slots into ``to_free_slots``; start offsets are prefix sums over earlier
    rows so both outputs are densely packed.
    """
    bid = tl.program_id(axis=0)
    offset = tl.arange(0, num_verify_tokens_upper)
    bs_offset = tl.arange(0, bs_upper)

    # write the first part to tgt_cache_loc
    # Packed start = sum of earlier accepted lengths (+bid for bonus tokens).
    accept_len_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
    copy_len = tl.load(accept_length + bid) + 1
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
    )
    tl.store(
        tgt_cache_loc + tgt_cache_loc_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )

    # write the second part to to_free_num_pages
    # Freed slots come from the tail of this row of out_cache_loc.
    to_free_num_slots_all = tl.load(to_free_num_slots + bs_offset, mask=bs_offset < bid)
    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
    to_free_slots_start = tl.sum(to_free_num_slots_all)

    copy_len = to_free_num_slots_cur
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
        mask=offset < copy_len,
    )
    tl.store(
        to_free_slots + to_free_slots_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def get_src_tgt_cache_loc(
|
||||
seq_lens: torch.Tensor,
|
||||
out_cache_loc: torch.Tensor,
|
||||
accept_index: torch.Tensor,
|
||||
accept_length: torch.Tensor,
|
||||
draft_token_num: int,
|
||||
page_size: int,
|
||||
):
|
||||
src_cache_loc = out_cache_loc[accept_index]
|
||||
tgt_cache_loc = torch.empty_like(src_cache_loc)
|
||||
extended_len = seq_lens + draft_token_num
|
||||
keep_len = torch.minimum(
|
||||
(seq_lens + accept_length + 1 + page_size - 1) // page_size * page_size,
|
||||
extended_len,
|
||||
)
|
||||
to_free_num_slots = extended_len - keep_len
|
||||
return src_cache_loc, tgt_cache_loc, to_free_num_slots
|
||||
|
||||
|
||||
@triton.jit
def filter_finished_cache_loc_kernel(
    out_cache_loc,
    tgt_cache_loc,
    accept_length,
    accept_length_filter,
    bs_upper: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
):
    """Compact accepted cache locations, dropping finished requests.

    One program per request. ``accept_length_filter`` is expected to be 0 for
    finished requests and ``accept_length + 1`` for unfinished ones (as built
    by ``create_accept_length_filter``), so finished requests copy nothing and
    surviving rows are packed contiguously into ``out_cache_loc``.
    """
    bid = tl.program_id(0)
    bs_offset = tl.arange(0, bs_upper)

    # Input offset: exclusive prefix sum of (accept_length + 1) over earlier
    # requests (the "+ bid" accounts for the extra +1 slot per request).
    accept_length_all = tl.load(accept_length + bs_offset, mask=bs_offset < bid)
    old_start = tl.sum(accept_length_all) + bid

    # Output offset: exclusive prefix sum of the filtered lengths, so
    # finished requests contribute nothing to the packed output.
    accept_length_filter_all = tl.load(
        accept_length_filter + bs_offset, mask=bs_offset < bid
    )
    new_start = tl.sum(accept_length_filter_all)

    copy_len = tl.load(accept_length_filter + bid)
    copy_offset = tl.arange(0, num_verify_tokens_upper)
    value = tl.load(
        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
    )
    tl.store(
        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
    )
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def create_accept_length_filter(
|
||||
accept_length: torch.Tensor,
|
||||
unfinished_index_device: torch.Tensor,
|
||||
seq_lens: torch.Tensor,
|
||||
):
|
||||
accept_length_filter = torch.zeros_like(accept_length)
|
||||
accept_length_filter[unfinished_index_device] = (
|
||||
accept_length[unfinished_index_device] + 1
|
||||
)
|
||||
seq_lens.add_(accept_length + 1)
|
||||
return accept_length_filter
|
||||
|
||||
|
||||
@torch.compile(dynamic=True)
|
||||
def select_top_k_tokens(
|
||||
i: int,
|
||||
topk_p: torch.Tensor,
|
||||
topk_index: torch.Tensor,
|
||||
hidden_states: torch.Tensor,
|
||||
scores: torch.Tensor,
|
||||
topk: int,
|
||||
):
|
||||
if i == 0:
|
||||
# The first step after extend
|
||||
input_ids = topk_index.flatten()
|
||||
hidden_states = hidden_states.repeat_interleave(topk, dim=0)
|
||||
scores = topk_p # shape: (b, topk)
|
||||
|
||||
tree_info = (
|
||||
topk_p.unsqueeze(1), # shape: (b, 1, topk)
|
||||
topk_index, # shape: (b, topk)
|
||||
torch.arange(-1, topk, dtype=torch.long, device="cuda")
|
||||
.unsqueeze(0)
|
||||
.repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
|
||||
)
|
||||
else:
|
||||
# The later decode steps
|
||||
expand_scores = torch.mul(
|
||||
scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
|
||||
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
|
||||
topk_cs_p, topk_cs_index = fast_topk(
|
||||
expand_scores.flatten(start_dim=1), topk, dim=-1
|
||||
) # (b, topk)
|
||||
scores = topk_cs_p # shape: (b, topk)
|
||||
|
||||
topk_index = topk_index.reshape(-1, topk**2)
|
||||
input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()
|
||||
|
||||
if hidden_states.shape[0] > 0:
|
||||
selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
|
||||
0, hidden_states.shape[0], step=topk, device="cuda"
|
||||
).repeat_interleave(topk)
|
||||
hidden_states = hidden_states[selected_input_index, :]
|
||||
|
||||
tree_info = (
|
||||
expand_scores, # shape: (b, topk, topk)
|
||||
topk_index, # shape: (b, topk * topk)
|
||||
topk_cs_index + (topk**2 * (i - 1) + topk), # shape: (b, topk)
|
||||
)
|
||||
|
||||
return input_ids, hidden_states, scores, tree_info
|
||||
|
||||
|
||||
def _generate_simulated_accept_index(
    accept_index,
    predict,
    accept_length,
    bs,
    spec_steps,
    simulate_acc_len: float = SIMULATE_ACC_LEN,
    simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
    """Overwrite verification results with a simulated acceptance length.

    Used to benchmark speculative decoding with a controlled acceptance rate.
    Mutates ``accept_length`` and ``predict`` in place and returns a new
    ``(bs, spec_steps + 1)`` accept-index tensor whose first
    ``simulate_acc_len`` columns are consecutive positions starting from each
    request's first accepted index (the rest are -1).
    """
    assert simulate_acc_len > 0.0

    if simulate_acc_method == "multinomial":
        simulated_values = torch.normal(
            mean=simulate_acc_len,
            std=1.0,
            size=(1,),
            device="cpu",
        )
        # clamp simulated values to be between 1 and self.spec_steps
        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
        simulate_acc_len = int(simulated_values.round().item())
    elif simulate_acc_method == "match-expected":
        # multinomial sampling does not match the expected length
        # we keep it for the sake of compatibility of existing tests
        # but it's better to use "match-expected" for the cases that need to
        # match the expected length, One caveat is that this will only sample
        # either round down or round up of the expected length
        simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
        lower = int(simulate_acc_len // 1)
        upper = lower + 1 if lower < spec_steps + 1 else lower
        if lower == upper:
            simulate_acc_len = lower
        else:
            weight_upper = simulate_acc_len - lower
            weight_lower = 1.0 - weight_upper
            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
            sampled_index = torch.multinomial(probs, num_samples=1)
            simulate_acc_len = lower if sampled_index == 0 else upper
    else:
        # Fix: report the method that was actually passed in; previously this
        # interpolated the module-level default SIMULATE_ACC_METHOD, which
        # masked the offending argument value.
        raise ValueError(f"Invalid simulate_acc_method: {simulate_acc_method}")

    # Fix: use the input tensors' device consistently instead of mixing a
    # hard-coded "cuda" with accept_index.device.
    device = accept_index.device
    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
    sim_accept_index = torch.full(
        (bs, spec_steps + 1), -1, dtype=torch.int32, device=device
    )
    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
        simulate_acc_len, device=device
    )
    accept_length.fill_(simulate_acc_len - 1)
    predict.fill_(100)  # some legit token id
    return sim_accept_index
|
||||
|
||||
|
||||
def traverse_tree(
    retrieve_next_token: torch.Tensor,
    retrieve_next_sibling: torch.Tensor,
    draft_tokens: torch.Tensor,
    grammar: BaseGrammarObject,
    allocate_token_bitmask: torch.Tensor,
):
    """
    Walk the tree built by the draft model and fill ``allocate_token_bitmask``
    with the grammar's vocab mask at every node reachable through
    grammar-accepted tokens.
    """
    assert (
        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
    )

    allocate_token_bitmask.fill_(0)

    def visit(node, parent_pos):
        # Node 0 is the token already produced by the target model, so it is
        # always accepted from the previous iteration; any other node is
        # accepted only if its token's bit is set in the parent's mask
        # (32 boolean bitmask values are packed into each 32-bit integer).
        if node == 0:
            token_allowed = True
        else:
            token_id = draft_tokens[node]
            mask_word = allocate_token_bitmask[parent_pos][token_id // 32]
            token_allowed = (mask_word & (1 << (token_id % 32))) != 0

        if token_allowed:
            if node != 0:
                # Advance the grammar by the accepted draft token.
                grammar.accept_token(draft_tokens[node])
            if not grammar.is_terminated():
                # Record the set of tokens allowed after this node.
                grammar.fill_vocab_mask(allocate_token_bitmask, node)
            child = retrieve_next_token[node]
            if child != -1:
                visit(child, node)
            if node != 0:
                # Undo the accepted token before exploring siblings.
                grammar.rollback(1)

        sibling = retrieve_next_sibling[node]
        if sibling != -1:
            visit(sibling, parent_pos)

    visit(0, -1)
|
||||
|
||||
|
||||
def generate_token_bitmask(
    reqs: List[Req],
    verify_input: EagleVerifyInput,
    retrieve_next_token_cpu: torch.Tensor,
    retrieve_next_sibling_cpu: torch.Tensor,
    draft_tokens_cpu: torch.Tensor,
    vocab_size: int,
):
    """
    Generate the logit mask for structured output.

    A draft token may or may not be valid with respect to a request's
    grammar, so a DFS over the draft tree (``traverse_tree``) determines
    1. which tokens the grammar accepts, and
    2. the corresponding logit mask for each accepted node.

    Returns the batch-wide bitmask, or None if no request has a grammar.
    Also stores the last seen grammar on ``verify_input.grammar``.
    """
    num_draft_tokens = draft_tokens_cpu.shape[-1]
    assert len(reqs) == retrieve_next_token_cpu.shape[0]

    allocate_token_bitmask = None
    grammar = None
    for i, req in enumerate(reqs):
        if req.grammar is None:
            continue
        if allocate_token_bitmask is None:
            # Lazily allocate one mask row per draft token across the batch.
            allocate_token_bitmask = req.grammar.allocate_vocab_mask(
                vocab_size=vocab_size,
                batch_size=draft_tokens_cpu.numel(),
                device="cpu",
            )
        grammar = req.grammar
        start_ts = time.perf_counter()
        traverse_tree(
            retrieve_next_token_cpu[i],
            retrieve_next_sibling_cpu[i],
            draft_tokens_cpu[i],
            req.grammar,
            allocate_token_bitmask[
                i * num_draft_tokens : (i + 1) * num_draft_tokens
            ],
        )
        elapsed = time.perf_counter() - start_ts
        if elapsed > TREE_TRAVERSE_TIME_THRESHOLD:
            logger.warning(
                f"Bit mask generation took {elapsed} seconds with "
                f"grammar: {req.grammar}"
            )

    verify_input.grammar = grammar
    return allocate_token_bitmask
|
||||
@@ -30,7 +30,8 @@ from sglang.srt.model_executor.forward_batch_info import (
|
||||
)
|
||||
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
||||
from sglang.srt.operations_strategy import OperationsStrategy
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.eagle_info import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.spec_info import SpecInput
|
||||
from sglang.srt.utils import BumpAllocator, empty_context, get_bool_env_var, is_hip
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -48,7 +49,7 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
def get_token_num_per_seq(
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]] = None,
|
||||
spec_info: Optional[SpecInput] = None,
|
||||
):
|
||||
if forward_mode.is_target_verify():
|
||||
return spec_info.draft_token_num
|
||||
@@ -273,7 +274,7 @@ def compute_split_token_index(
|
||||
def compute_split_indices_for_cuda_graph_replay(
|
||||
forward_mode: ForwardMode,
|
||||
cuda_graph_num_tokens: int,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
forward_mode_for_tbo_split = (
|
||||
forward_mode if forward_mode != ForwardMode.IDLE else ForwardMode.DECODE
|
||||
@@ -333,7 +334,7 @@ class TboCudaGraphRunnerPlugin:
|
||||
forward_mode: ForwardMode,
|
||||
bs: int,
|
||||
num_token_non_padded: int,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[SpecInput],
|
||||
):
|
||||
token_num_per_seq = get_token_num_per_seq(
|
||||
forward_mode=forward_mode, spec_info=spec_info
|
||||
|
||||
Reference in New Issue
Block a user