[Eagle] Refactor eagle speculative decoding (#3986)
Co-authored-by: Ke Bao <ISPObaoke@163.com>
@@ -230,7 +230,7 @@ def extend(reqs, model_runner):
batch = ScheduleBatch.init_new(
reqs=reqs,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
tree_cache=None,
model_config=model_runner.model_config,
enable_overlap=False,

@@ -326,7 +326,7 @@ def latency_test_run_once(
# Clear the pools.
model_runner.req_to_token_pool.clear()
model_runner.token_to_kv_pool.clear()
model_runner.token_to_kv_pool_allocator.clear()

measurement_results = {
"run_name": run_name,
@@ -20,14 +20,15 @@ import triton.language as tl

from sglang.global_config import global_config
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
from sglang.srt.layers.dp_attention import get_attention_tp_size
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.utils import is_flashinfer_available

if TYPE_CHECKING:
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.spec_info import SpecInfo

if is_flashinfer_available():
from flashinfer import (

@@ -36,6 +37,7 @@ if is_flashinfer_available():
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.cascade import merge_state
from flashinfer.decode import PosEncodingMode

class WrapperDispatch(Enum):

@@ -113,6 +115,7 @@ class FlashInferAttnBackend(AttentionBackend):
device=model_runner.device,
)
self.workspace_buffer = global_workspace_buffer

max_bs = model_runner.req_to_token_pool.size
if kv_indptr_buf is None:
self.kv_indptr = [

@@ -133,10 +136,13 @@ class FlashInferAttnBackend(AttentionBackend):
assert self.num_wrappers == 1
self.kv_last_page_len = kv_last_page_len_buf

self.qo_indptr = [
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
for _ in range(self.num_wrappers)
]
if not self.skip_prefill:
self.qo_indptr = [
torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
for _ in range(self.num_wrappers)
]

self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
self.workspace_buffer, "NHD"
@@ -276,7 +282,7 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if forward_mode.is_decode_or_idle():
decode_wrappers = []

@@ -346,7 +352,7 @@ class FlashInferAttnBackend(AttentionBackend):
seq_lens_sum: int,
encoder_lens: Optional[torch.Tensor],
forward_mode: ForwardMode,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if forward_mode.is_decode_or_idle():
self.indices_updater_decode.update(

@@ -526,7 +532,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()

@@ -538,7 +544,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
decode_wrappers = decode_wrappers or self.decode_wrappers
self.call_begin_forward(

@@ -558,7 +564,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
for wrapper_id in range(2):
if wrapper_id == 0:

@@ -592,7 +598,7 @@ class FlashInferIndicesUpdaterDecode:
seq_lens_sum: int,
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
for wrapper_id in range(2):
if wrapper_id == 0:

@@ -623,7 +629,7 @@ class FlashInferIndicesUpdaterDecode:
paged_kernel_lens_sum: int,
kv_indptr: torch.Tensor,
kv_start_idx: torch.Tensor,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if spec_info is None:
bs = len(req_pool_indices)

@@ -642,9 +648,9 @@ class FlashInferIndicesUpdaterDecode:
self.req_to_token.shape[1],
)
else:
assert isinstance(spec_info, EagleDraftInput)
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1

wrapper.begin_forward(
kv_indptr,
kv_indices,

@@ -699,7 +705,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()

@@ -713,7 +719,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
if use_ragged:
paged_kernel_lens = prefix_lens

@@ -746,7 +752,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
for wrapper_id in range(2):
if wrapper_id == 0:

@@ -787,7 +793,7 @@ class FlashInferIndicesUpdaterPrefill:
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
use_ragged: bool,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
for wrapper_id in range(2):
if wrapper_id == 0:

@@ -829,10 +835,11 @@ class FlashInferIndicesUpdaterPrefill:
kv_indptr: torch.Tensor,
qo_indptr: torch.Tensor,
use_ragged: bool,
spec_info: Optional[SpecInfo],
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
):
bs = len(req_pool_indices)
bs = len(seq_lens)
if spec_info is None:
assert len(seq_lens) == len(req_pool_indices)
# Normal extend
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]

@@ -855,10 +862,14 @@ class FlashInferIndicesUpdaterPrefill:
qo_indptr = qo_indptr[: bs + 1]
custom_mask = None
else:
assert isinstance(spec_info, EagleDraftInput) or isinstance(
spec_info, EagleVerifyInput
)
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
self.req_to_token,
)
)
@@ -890,6 +901,11 @@ class FlashInferIndicesUpdaterPrefill:
)

# Use as a fast path to override the indptr in flashinfer's plan function
# This is used to remove some host-to-device copy overhead.
global global_override_indptr_cpu

class FlashInferMultiStepDraftBackend:
"""
Wrap multiple flashinfer attention backends as one for multiple consecutive

@@ -907,6 +923,7 @@ class FlashInferMultiStepDraftBackend:
self.topk = topk
self.speculative_num_steps = speculative_num_steps
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices

max_bs = model_runner.req_to_token_pool.size * self.topk
self.kv_indptr = torch.zeros(
(

@@ -929,7 +946,9 @@ class FlashInferMultiStepDraftBackend:
kv_last_page_len_buf=self.kv_last_page_len,
)
)

self.max_context_len = self.attn_backends[0].max_context_len

# Cached variables for generate_draft_decode_kv_indices
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]

@@ -959,13 +978,23 @@ class FlashInferMultiStepDraftBackend:
triton.next_power_of_2(bs),
)

assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)

# Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
global global_override_indptr_cpu

for i in range(self.speculative_num_steps - 1):
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
: seq_lens_sum * self.topk + bs * (i + 1)
]
global_override_indptr_cpu = indptr_cpu_whole[i]
call_fn(i, forward_batch)

global_override_indptr_cpu = None

def init_forward_metadata(self, forward_batch: ForwardBatch):
kv_indices = torch.zeros(
(
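The hunk above copies every draft step's `kv_indptr` to the host once, then hands the cached row to each of the `speculative_num_steps - 1` plan calls through `global_override_indptr_cpu`, instead of issuing a blocking device-to-host transfer per step. A minimal sketch of that pattern; the `plan_one_step` helper and the buffer shapes are assumptions for illustration, not sglang or flashinfer APIs:

```python
import torch

def plan_one_step(step: int, indptr_cpu: torch.Tensor) -> None:
    # Stands in for one flashinfer plan call; it only needs a host-side indptr.
    assert indptr_cpu.device.type == "cpu"

num_steps, bs = 4, 8
device = "cuda" if torch.cuda.is_available() else "cpu"
kv_indptr = torch.zeros((num_steps, bs + 1), dtype=torch.int32, device=device)

# One device-to-host copy covering all steps, done outside the loop.
indptr_cpu_whole = kv_indptr[:, : bs + 1].cpu()

for i in range(num_steps - 1):
    plan_one_step(i, indptr_cpu_whole[i])  # reuse the cached row for step i
```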
@@ -977,6 +1006,8 @@ class FlashInferMultiStepDraftBackend:
)

def call_fn(i, forward_batch):
assert forward_batch.spec_info is not None
assert isinstance(forward_batch.spec_info, EagleDraftInput)
forward_batch.spec_info.kv_indptr = (
forward_batch.spec_info.kv_indptr.clone()
)

@@ -993,6 +1024,7 @@ class FlashInferMultiStepDraftBackend:
dtype=torch.int32,
device="cuda",
)

for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]

@@ -1031,43 +1063,6 @@ class FlashInferMultiStepDraftBackend:
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)

@triton.jit
def create_flashinfer_kv_indices_triton(
req_to_token_ptr,  # [max_batch, max_context_len]
req_pool_indices_ptr,
page_kernel_lens_ptr,
kv_indptr,
kv_start_idx,
kv_indices_ptr,
req_to_token_ptr_stride: tl.constexpr,
):
BLOCK_SIZE: tl.constexpr = 512
pid = tl.program_id(axis=0)

req_pool_index = tl.load(req_pool_indices_ptr + pid)
kv_indices_offset = tl.load(kv_indptr + pid)

kv_start = 0
kv_end = 0
if kv_start_idx:
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
kv_end = kv_start
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)

num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
for i in range(num_loop):
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
mask = offset < kv_end - kv_start
data = tl.load(
req_to_token_ptr
+ req_pool_index * req_to_token_ptr_stride
+ kv_start
+ offset,
mask=mask,
)
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)

def should_use_tensor_core(
kv_cache_dtype: torch.dtype,
num_attention_heads: int,
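The Triton kernel removed in the hunk above (the import change earlier in this diff pulls it from `sglang.srt.layers.attention.utils` instead) gathers, for each request, the token slots recorded in `req_to_token` into a flat `kv_indices` buffer at the offsets given by `kv_indptr`. A plain-PyTorch reference of that gather, intended only to illustrate what the kernel computes, not as a drop-in replacement:

```python
from typing import Optional

import torch

def create_kv_indices_reference(
    req_to_token: torch.Tensor,       # [max_batch, max_context_len]
    req_pool_indices: torch.Tensor,   # [bs]
    paged_kernel_lens: torch.Tensor,  # [bs]
    kv_indptr: torch.Tensor,          # [bs + 1] start offsets into kv_indices
    kv_start_idx: Optional[torch.Tensor],
    kv_indices: torch.Tensor,         # flat output buffer
) -> None:
    for pid in range(req_pool_indices.numel()):
        pool_idx = int(req_pool_indices[pid])
        start = int(kv_start_idx[pid]) if kv_start_idx is not None else 0
        end = start + int(paged_kernel_lens[pid])
        out_off = int(kv_indptr[pid])
        # Copy this request's token slots into the flat index buffer.
        kv_indices[out_off : out_off + (end - start)] = req_to_token[pool_idx, start:end]
```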
@@ -1089,6 +1084,21 @@ def should_use_tensor_core(
if env_override is not None:
return env_override.lower() == "true"

# Try to use _grouped_size_compiled_for_decode_kernels if available
# This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
try:
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels

if not _grouped_size_compiled_for_decode_kernels(
num_attention_heads,
num_kv_heads,
):
return True
else:
return False
except (ImportError, AttributeError):
pass

# Calculate GQA group size
gqa_group_size = num_attention_heads // num_kv_heads

@@ -1118,12 +1128,18 @@ def fast_decode_plan(
sm_scale: Optional[float] = None,
rope_scale: Optional[float] = None,
rope_theta: Optional[float] = None,
**kwargs,
non_blocking: bool = True,
) -> None:
"""A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
"""
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
Modifications:
- Remove unnecessary device-to-device copy for the cuda graph buffers.
- Remove unnecessary host-to-device copy for the metadata buffers.
"""
batch_size = len(last_page_len)
if logits_soft_cap is None:
logits_soft_cap = 0.0

if self.is_cuda_graph_enabled:
if batch_size != self._fixed_batch_size:
raise ValueError(

@@ -1136,13 +1152,19 @@ def fast_decode_plan(
raise ValueError(
"The size of indices should be less than or equal to the allocated buffer"
)
# Skip these copies
# self._paged_kv_indptr_buf.copy_(indptr)
# self._paged_kv_indices_buf[: len(indices)] = indices
# self._paged_kv_last_page_len_buf.copy_(last_page_len)
else:
self._paged_kv_indptr_buf = indptr
self._paged_kv_indices_buf = indices
self._paged_kv_last_page_len_buf = last_page_len

# NOTE(Zihao): the following tensors acts as placeholder to pass dtype info
if not q_data_type:
q_data_type = data_type

if not hasattr(self, "empty_q_data"):
self.empty_q_data = torch.empty(
0,

@@ -1159,6 +1181,7 @@ def fast_decode_plan(
),
)
self.last_page_len = torch.ones(32768, dtype=torch.int32)

empty_q_data = self.empty_q_data
empty_kv_cache = self.empty_kv_cache
stream = torch.cuda.current_stream()
@@ -156,6 +156,7 @@ class TritonAttnBackend(AttentionBackend):
spec_info.generate_attn_arg_prefill(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
None,
self.req_to_token,
)
)

@@ -22,7 +22,7 @@ from typing import List, Optional

import torch

from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost

logger = logging.getLogger(__name__)

@@ -128,7 +128,7 @@ class HiCacheController:
def __init__(
self,
mem_pool_device: MHATokenToKVPool,
mem_pool_host: MLATokenToKVPoolHost,
mem_pool_host: MHATokenToKVPoolHost,
write_policy: str = "write_through_selective",
):
@@ -44,18 +44,16 @@ from sglang.srt.configs.model_config import ModelConfig
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.chunk_cache import ChunkCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.sampling.sampling_params import SamplingParams
from sglang.srt.server_args import ServerArgs

if TYPE_CHECKING:
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

# Put some global args for easy access

@@ -523,7 +521,7 @@ class ScheduleBatch:
# Request, memory pool, and cache
reqs: List[Req]
req_to_token_pool: ReqToTokenPool = None
token_to_kv_pool: BaseTokenToKVPool = None
token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
tree_cache: BasePrefixCache = None

# Batch configs

@@ -596,7 +594,7 @@ class ScheduleBatch:
cls,
reqs: List[Req],
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
tree_cache: BasePrefixCache,
model_config: ModelConfig,
enable_overlap: bool,

@@ -606,7 +604,7 @@ class ScheduleBatch:
return cls(
reqs=reqs,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool=token_to_kv_pool,
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
tree_cache=tree_cache,
model_config=model_config,
enable_overlap=enable_overlap,

@@ -637,19 +635,19 @@ class ScheduleBatch:
return req_pool_indices

def alloc_token_slots(self, num_tokens: int):
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

if out_cache_loc is None:
if self.tree_cache is not None:
self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)

if out_cache_loc is None:
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
logger.error(
f"{phase_str} out of memory. Try to lower your batch size.\n"
f"Try to allocate {num_tokens} tokens.\n"
f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
)
if self.tree_cache is not None:
self.tree_cache.pretty_print()
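`alloc_token_slots` above now goes through the allocator: try to allocate, evict cached prefixes from the tree cache back into the allocator on failure, then retry once. A condensed sketch of that retry pattern with stand-in classes (the class shapes and sizes here are invented for illustration and are not the real sglang types):

```python
from typing import Optional

import torch

class TinyAllocator:
    """Stand-in for TokenToKVPoolAllocator: hands out free token-slot indices."""

    def __init__(self, size: int):
        self.free_slots = list(range(size))

    def alloc(self, n: int) -> Optional[torch.Tensor]:
        if n > len(self.free_slots):
            return None
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return torch.tensor(out, dtype=torch.int64)

    def free(self, indices: torch.Tensor) -> None:
        self.free_slots.extend(indices.tolist())

def alloc_with_eviction(allocator: TinyAllocator, tree_cache, num_tokens: int):
    out = allocator.alloc(num_tokens)
    if out is None and tree_cache is not None:
        # Evict cached entries back into the allocator, then retry once.
        tree_cache.evict(num_tokens, allocator.free)
        out = allocator.alloc(num_tokens)
    return out
```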
@@ -917,12 +915,12 @@ class ScheduleBatch:

def check_decode_mem(self, buf_multiplier=1):
bs = len(self.reqs) * buf_multiplier
if self.token_to_kv_pool.available_size() >= bs:
if self.token_to_kv_pool_allocator.available_size() >= bs:
return True

self.tree_cache.evict(bs, self.token_to_kv_pool.free)
self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)

if self.token_to_kv_pool.available_size() >= bs:
if self.token_to_kv_pool_allocator.available_size() >= bs:
return True

return False

@@ -945,6 +943,10 @@ class ScheduleBatch:
reverse=True,
)

retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True

def get_required_tokens(num_reqs: int):
headroom_for_spec_decode = 0
if server_args.speculative_algorithm:

@@ -958,18 +960,15 @@ class ScheduleBatch:
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
)

retracted_reqs = []
seq_lens_cpu = self.seq_lens.cpu().numpy()
first_iter = True
while (
self.token_to_kv_pool.available_size()
self.token_to_kv_pool_allocator.available_size()
< get_required_tokens(len(sorted_indices))
or first_iter
):
if len(sorted_indices) == 1:
# Corner case: only one request left
assert (
self.token_to_kv_pool.available_size() > 0
self.token_to_kv_pool_allocator.available_size() > 0
), "No space left for only one request"
break

@@ -983,7 +982,7 @@ class ScheduleBatch:
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, : seq_lens_cpu[idx]
]
self.token_to_kv_pool.free(token_indices)
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)
del self.tree_cache.entries[req.rid]
else:

@@ -992,7 +991,7 @@ class ScheduleBatch:
token_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
]
self.token_to_kv_pool.free(token_indices)
self.token_to_kv_pool_allocator.free(token_indices)
self.req_to_token_pool.free(req.req_pool_idx)

# release the last node

@@ -1001,10 +1000,13 @@ class ScheduleBatch:
# NOTE(lsyin): we should use the newly evictable memory instantly.
residual_size = (
len(sorted_indices) * global_config.retract_decode_steps
- self.token_to_kv_pool.available_size()
- self.token_to_kv_pool_allocator.available_size()
)
residual_size = max(0, residual_size)
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
self.tree_cache.evict(
residual_size, self.token_to_kv_pool_allocator.free
)

req.reset_for_retract()

self.filter_batch(keep_indices=sorted_indices)

@@ -1183,7 +1185,7 @@ class ScheduleBatch:
if self.spec_info:
self.spec_info.merge_batch(other.spec_info)

def get_model_worker_batch(self):
def get_model_worker_batch(self) -> ModelWorkerBatch:
if self.forward_mode.is_decode_or_idle():
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
else:

@@ -1273,7 +1275,7 @@ class ModelWorkerBatch:
req_pool_indices: torch.Tensor
# The sequence length
seq_lens: torch.Tensor
# The indices of output tokens in the token_to_kv_pool
# The indices of output tokens in the token_to_kv_pool_allocator
out_cache_loc: torch.Tensor

# The sum of all sequence lengths
@@ -22,9 +22,13 @@ from typing import Dict, List, Optional, Set, Union

import torch

from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
from sglang.srt.managers.schedule_batch import (
Req,
ScheduleBatch,
global_server_args_dict,
)
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode

# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.

@@ -75,7 +79,7 @@ class SchedulePolicy:

# It is used to find the matching prefix for in-batch prefix caching.
self.waiting_queue_radix_tree = RadixCache(
req_to_token_pool=None, token_to_kv_pool=None, disable=False
req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
)

def calc_priority(self, waiting_queue: List[Req]) -> bool:

@@ -251,7 +255,7 @@ class PrefillAdder:
def __init__(
self,
tree_cache: BasePrefixCache,
token_to_kv_pool: BaseTokenToKVPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
running_batch: ScheduleBatch,
new_token_ratio: float,
rem_input_tokens: int,

@@ -259,7 +263,7 @@ class PrefillAdder:
mixed_with_decode_tokens: int = 0,
):
self.tree_cache = tree_cache
self.token_to_kv_pool = token_to_kv_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.running_batch = running_batch
self.new_token_ratio = new_token_ratio
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens

@@ -291,7 +295,7 @@ class PrefillAdder:
@property
def rem_total_tokens(self):
return (
self.token_to_kv_pool.available_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
- self.rem_total_token_offset
)

@@ -299,7 +303,7 @@ class PrefillAdder:
@property
def cur_rem_tokens(self):
return (
self.token_to_kv_pool.available_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
- self.cur_rem_token_offset
)

@@ -332,7 +336,6 @@ class PrefillAdder:
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
self.can_run_list.append(req)

self._prefill_one_req(
0,
req.extend_input_len,

@@ -400,8 +403,8 @@ class PrefillAdder:
tokens_freed += tokens_occupied

if (
self.rem_chunk_tokens is None
or req.extend_input_len <= self.rem_chunk_tokens
self.rem_chunk_tokens is None  # chunked prefill is disabled
or req.extend_input_len <= self.rem_chunk_tokens  # it is the last chunk
):
# Non-chunked prefill
self.can_run_list.append(req)

@@ -411,10 +414,11 @@ class PrefillAdder:
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
)
else:
if self.rem_chunk_tokens == 0:
return AddReqResult.OTHER

# Chunked prefill
trunc_len = self.rem_chunk_tokens
if trunc_len == 0:
return AddReqResult.OTHER

req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[:trunc_len]

@@ -457,10 +461,11 @@ class PrefillAdder:
),
)
else:
if self.rem_chunk_tokens == 0:
return AddReqResult.OTHER

# Chunked prefill
trunc_len = self.rem_chunk_tokens
if trunc_len == 0:
return AddReqResult.OTHER

req.extend_input_len = trunc_len
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
@@ -164,7 +164,7 @@ class Scheduler:
self.server_args.speculative_num_draft_tokens
+ (
self.server_args.speculative_eagle_topk
* self.server_args.speculative_num_steps
* self.server_args.speculative_num_draft_tokens
)
)
if not self.spec_algorithm.is_none()

@@ -309,7 +309,9 @@ class Scheduler:
)

# Init memory pool and cache
self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
self.tp_worker.get_memory_pool()
)

if (
server_args.chunked_prefill_size is not None

@@ -317,18 +319,18 @@ class Scheduler:
):
self.tree_cache = ChunkCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
if self.enable_hierarchical_cache:
self.tree_cache = HiRadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
else:
self.tree_cache = RadixCache(
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool=self.token_to_kv_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
disable=server_args.disable_radix_cache,
)

@@ -458,7 +460,6 @@ class Scheduler:
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
(ProfileReq, self.profile),
(GetInternalStateReq, self.get_internal_state),
(SetInternalStateReq, self.set_internal_state),
]
)

@@ -809,7 +810,8 @@ class Scheduler:
running_bs: int,
):
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
self._largest_prefill_len = max(
self._largest_prefill_len, adder.log_input_tokens

@@ -844,7 +846,8 @@ class Scheduler:
self.num_generated_tokens = 0
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
num_used = self.max_total_num_tokens - (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)

if RECORD_STEP_TIME:

@@ -894,7 +897,8 @@ class Scheduler:

def check_memory(self):
available_size = (
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
self.token_to_kv_pool_allocator.available_size()
+ self.tree_cache.evictable_size()
)
protected_size = self.tree_cache.protected_size()
memory_leak = available_size != (

@@ -999,7 +1003,7 @@ class Scheduler:
# Prefill policy
adder = PrefillAdder(
self.tree_cache,
self.token_to_kv_pool,
self.token_to_kv_pool_allocator,
self.running_batch,
self.new_token_ratio,
self.max_prefill_tokens,

@@ -1099,7 +1103,7 @@ class Scheduler:
new_batch = ScheduleBatch.init_new(
can_run_list,
self.req_to_token_pool,
self.token_to_kv_pool,
self.token_to_kv_pool_allocator,
self.tree_cache,
self.model_config,
self.enable_overlap,

@@ -1143,8 +1147,6 @@ class Scheduler:

retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
self.new_token_ratio = new_token_ratio
if self.draft_worker:
self.draft_worker.finish_request(retracted_reqs)

logger.info(
"Decode out of memory happened. "

@@ -1184,11 +1186,12 @@ class Scheduler:
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
model_worker_batch
)
bid = model_worker_batch.bid
else:
(
logits_output,
next_token_ids,
model_worker_batch,
bid,
num_accepted_tokens,
) = self.draft_worker.forward_batch_speculative_generation(batch)
self.spec_num_total_accepted_tokens += (

@@ -1214,7 +1217,7 @@ class Scheduler:
next_token_ids=next_token_ids,
extend_input_len_per_req=extend_input_len_per_req,
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
bid=model_worker_batch.bid,
bid=bid,
)
else:  # embedding or reward model
model_worker_batch = batch.get_model_worker_batch()

@@ -1230,6 +1233,7 @@ class Scheduler:
result: Union[GenerationBatchResult, EmbeddingBatchResult],
):
if batch.forward_mode.is_decode():
assert isinstance(result, GenerationBatchResult)
self.process_batch_result_decode(batch, result)
if batch.is_empty():
self.running_batch = None

@@ -1302,7 +1306,7 @@ class Scheduler:
if self.is_mixed_chunk and self.enable_overlap and req.finished():
# Free the one delayed token for the mixed decode batch
j = len(batch.out_cache_loc) - len(batch.reqs) + i
self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
continue

if req.is_chunked <= 0:

@@ -1420,23 +1424,27 @@ class Scheduler:
self.num_generated_tokens += len(batch.reqs)

if self.enable_overlap:
assert batch.spec_algorithm.is_none()
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
next_token_logprobs = logits_output.next_token_logprobs
else:
elif batch.spec_algorithm.is_none():
# spec decoding handles output logprobs inside verify process.
next_token_ids = next_token_ids.tolist()
if batch.return_logprob:
next_token_logprobs = logits_output.next_token_logprobs.tolist()

self.token_to_kv_pool.free_group_begin()
self.token_to_kv_pool_allocator.free_group_begin()

# Check finish condition
# NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
# We should ignore using next_token_ids for spec decoding cases.
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
if req.is_retracted:
continue

if self.enable_overlap and req.finished():
# Free the one delayed token
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
continue

if batch.spec_algorithm.is_none():

@@ -1479,7 +1487,7 @@ class Scheduler:
batch.next_batch_sampling_info.sampling_info_done.set()
self.stream_output(batch.reqs, batch.return_logprob)

self.token_to_kv_pool.free_group_end()
self.token_to_kv_pool_allocator.free_group_end()

self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
if (

@@ -1718,9 +1726,6 @@ class Scheduler:
and not self.model_config.is_multimodal_gen
)
):
if self.draft_worker and req.finished():
self.draft_worker.finish_request(req)

rids.append(req.rid)
finished_reasons.append(
req.finished_reason.to_json() if req.finished_reason else None

@@ -1860,7 +1865,7 @@ class Scheduler:
idle_batch = ScheduleBatch.init_new(
[],
self.req_to_token_pool,
self.token_to_kv_pool,
self.token_to_kv_pool_allocator,
self.tree_cache,
self.model_config,
self.enable_overlap,

@@ -1916,11 +1921,11 @@ class Scheduler:
if self.grammar_backend:
self.grammar_backend.reset()
self.req_to_token_pool.clear()
self.token_to_kv_pool.clear()
self.token_to_kv_pool_allocator.clear()

if not self.spec_algorithm.is_none():
self.draft_worker.model_runner.req_to_token_pool.clear()
self.draft_worker.model_runner.token_to_kv_pool.clear()
self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()

self.num_generated_tokens = 0
self.forward_ct_decode = 0
@@ -82,8 +82,6 @@ from sglang.srt.managers.io_struct import (
ResumeMemoryOccupationReqInput,
ResumeMemoryOccupationReqOutput,
SessionParams,
SetInternalStateReq,
SetInternalStateReqOutput,
TokenizedEmbeddingReqInput,
TokenizedGenerateReqInput,
UpdateWeightFromDiskReqInput,

@@ -257,9 +255,6 @@ class TokenizerManager:
self.get_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)
self.set_internal_state_communicator = _Communicator(
self.send_to_scheduler, server_args.dp_size
)

self._result_dispatcher = TypeBasedDispatcher(
[

@@ -309,10 +304,6 @@ class TokenizerManager:
GetInternalStateReqOutput,
self.get_internal_state_communicator.handle_recv,
),
(
SetInternalStateReqOutput,
self.set_internal_state_communicator.handle_recv,
),
(HealthCheckOutput, lambda x: None),
]
)

@@ -774,14 +765,6 @@ class TokenizerManager:
)
return res[0].internal_state

async def set_internal_state(
self, obj: SetInternalStateReq
) -> SetInternalStateReqOutput:
res: List[SetInternalStateReqOutput] = (
await self.set_internal_state_communicator(obj)
)
return res[0]

def get_log_request_metadata(self):
max_length = None
skip_names = None

@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
UpdateWeightsFromTensorReqInput,
)
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.server_args import ServerArgs

@@ -49,6 +50,8 @@ class TpModelWorker:
dp_rank: Optional[int],
nccl_port: int,
is_draft_worker: bool = False,
req_to_token_pool: Optional[ReqToTokenPool] = None,
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
):
# Parse args
self.tp_rank = tp_rank

@@ -77,6 +80,8 @@ class TpModelWorker:
nccl_port=nccl_port,
server_args=server_args,
is_draft_worker=is_draft_worker,
req_to_token_pool=req_to_token_pool,
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
)
if server_args.skip_tokenizer_init:
self.tokenizer = self.processor = None

@@ -154,7 +159,7 @@ class TpModelWorker:
def get_memory_pool(self):
return (
self.model_runner.req_to_token_pool,
self.model_runner.token_to_kv_pool,
self.model_runner.token_to_kv_pool_allocator,
)

def forward_batch_generation(

@@ -100,7 +100,7 @@ class TpModelWorkerClient:
def get_memory_pool(self):
return (
self.worker.model_runner.req_to_token_pool,
self.worker.model_runner.token_to_kv_pool,
self.worker.model_runner.token_to_kv_pool_allocator,
)

def forward_thread_func(self):
@@ -1,13 +1,12 @@
from __future__ import annotations

"""Cache for chunked prefill, used when RadixCache is disabled."""

from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple

import torch

from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req

@@ -21,11 +20,13 @@ class ChunkCacheEntry:

class ChunkCache(BasePrefixCache):
def __init__(
self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
):
self.disable = True
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.entries: Dict[str, ChunkCacheEntry] = {}

self.reset()

@@ -51,7 +52,7 @@ class ChunkCache(BasePrefixCache):
req.req_pool_idx, :token_id_len
]
self.req_to_token_pool.free(req.req_pool_idx)
self.token_to_kv_pool.free(kv_indices)
self.token_to_kv_pool_allocator.free(kv_indices)

if req.rid in self.entries:
del self.entries[req.rid]

@@ -91,3 +92,6 @@ class ChunkCache(BasePrefixCache):

def protected_size(self):
return 0

def pretty_print(self):
return ""

@@ -7,8 +7,8 @@ import torch

from sglang.srt.managers.cache_controller import HiCacheController
from sglang.srt.mem_cache.memory_pool import (
BaseTokenToKVPool,
MLATokenToKVPoolHost,
MHATokenToKVPool,
MHATokenToKVPoolHost,
ReqToTokenPool,
)
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match

@@ -21,9 +21,9 @@ class HiRadixCache(RadixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: BaseTokenToKVPool,
token_to_kv_pool: MHATokenToKVPool,
):
self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool)
self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
self.cache_controller = HiCacheController(
token_to_kv_pool, self.token_to_kv_pool_host
)
@@ -20,9 +20,12 @@ Memory pool.

SGLang has two levels of memory pool.
ReqToTokenPool maps a a request to its token locations.
BaseTokenToKVPool maps a token location to its KV cache data.
TokenToKVPoolAllocator maps a token location to its KV cache data.
KVCache actually holds the physical kv cache. Allocation indices are allocated
by TokenToKVPoolAllocator
"""

import abc
import logging
import threading
from enum import IntEnum

@@ -89,7 +92,7 @@ class ReqToTokenPool:
self.free_slots = list(range(self.size))

class BaseTokenToKVPool:
class TokenToKVPoolAllocator:
"""A memory pool that maps a token location to its kv cache data."""

def __init__(

@@ -100,11 +103,6 @@ class BaseTokenToKVPool:
):
self.size = size
self.dtype = dtype
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.device = device

self.free_slots = None

@@ -148,15 +146,22 @@ class BaseTokenToKVPool:
self.is_in_free_group = False
self.free_group = []

class KVCache(abc.ABC):

@abc.abstractmethod
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError()

@abc.abstractmethod
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
raise NotImplementedError()

@abc.abstractmethod
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
raise NotImplementedError()

@abc.abstractmethod
def set_kv_buffer(
self,
layer: RadixAttention,

@@ -167,7 +172,7 @@ class BaseTokenToKVPool:
raise NotImplementedError()

class MHATokenToKVPool(BaseTokenToKVPool):
class MHATokenToKVPool(KVCache):

def __init__(
self,

@@ -179,8 +184,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
device: str,
enable_memory_saver: bool,
):
super().__init__(size, dtype, device)

self.size = size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)

@@ -297,7 +308,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
dst_2[loc] = src_2.to(dtype).view(store_dtype)

class MLATokenToKVPool(BaseTokenToKVPool):
class MLATokenToKVPool(KVCache):
def __init__(
self,
size: int,

@@ -308,8 +319,14 @@ class MLATokenToKVPool(BaseTokenToKVPool):
device: str,
enable_memory_saver: bool,
):
super().__init__(size, dtype, device)

self.size = size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
self.kv_lora_rank = kv_lora_rank

memory_saver_adapter = TorchMemorySaverAdapter.create(

@@ -356,7 +373,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
self.kv_buffer[layer_id][loc] = cache_k

class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
class DoubleSparseTokenToKVPool(KVCache):
def __init__(
self,
size: int,

@@ -368,8 +385,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
heavy_channel_num: int,
enable_memory_saver: bool,
):
super().__init__(size, dtype, device)

self.size = size
self.dtype = dtype
self.device = device
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
self.store_dtype = torch.uint8
else:
self.store_dtype = dtype
memory_saver_adapter = TorchMemorySaverAdapter.create(
enable=enable_memory_saver
)

@@ -437,12 +460,12 @@ def synchronized(func):
return wrapper

class MLATokenToKVPoolHost:
class MHATokenToKVPoolHost:

def __init__(
self,
device_pool: MHATokenToKVPool,
host_to_device_ratio: float = 4.0,
host_to_device_ratio: float = 2.0,
pin_memory: bool = False,  # no need to use pin memory with the double buffering
device: str = "cpu",
):
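The memory_pool.py hunks above split the old `BaseTokenToKVPool` into two roles: `TokenToKVPoolAllocator` tracks which token-slot indices are free, while `KVCache` subclasses (`MHATokenToKVPool`, `MLATokenToKVPool`, `DoubleSparseTokenToKVPool`) own the physical per-layer buffers addressed by those indices. A toy sketch of that division of labor, with simplified shapes and names that are not the real classes:

```python
import torch

class ToyKVCache:
    """Owns the physical KV buffers; addressed by token-slot locations."""

    def __init__(self, size: int, num_layers: int, head_dim: int):
        self.k = [torch.zeros(size, head_dim) for _ in range(num_layers)]
        self.v = [torch.zeros(size, head_dim) for _ in range(num_layers)]

    def set_kv(self, layer_id: int, loc: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        self.k[layer_id][loc] = k
        self.v[layer_id][loc] = v

class ToyAllocator:
    """Hands out token-slot indices; knows nothing about layers or buffers."""

    def __init__(self, size: int):
        self.free_slots = torch.arange(size)

    def alloc(self, n: int):
        if n > self.free_slots.numel():
            return None
        out, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return out

# The scheduler talks to the allocator; the model forward writes through the cache.
alloc = ToyAllocator(size=16)
cache = ToyKVCache(size=16, num_layers=2, head_dim=4)
loc = alloc.alloc(3)
cache.set_kv(0, loc, torch.randn(3, 4), torch.randn(3, 4))
```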
@@ -26,8 +26,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple

import torch

from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator

if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req

@@ -79,11 +80,11 @@ class RadixCache(BasePrefixCache):
def __init__(
self,
req_to_token_pool: ReqToTokenPool,
token_to_kv_pool: BaseTokenToKVPool,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
disable: bool = False,
):
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool = token_to_kv_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.disable = disable
self.reset()

@@ -139,7 +140,7 @@ class RadixCache(BasePrefixCache):
kv_indices = self.req_to_token_pool.req_to_token[
req.req_pool_idx, :token_ids_len
]
self.token_to_kv_pool.free(kv_indices)
self.token_to_kv_pool_allocator.free(kv_indices)
self.req_to_token_pool.free(req.req_pool_idx)
return

@@ -151,7 +152,9 @@ class RadixCache(BasePrefixCache):

# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)

# Remove req slot release the cache lock
self.req_to_token_pool.free(req.req_pool_idx)

@@ -171,7 +174,9 @@ class RadixCache(BasePrefixCache):

# Radix Cache takes one ref in memory pool
new_prefix_len = self.insert(token_ids, kv_indices.clone())
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
self.token_to_kv_pool_allocator.free(
kv_indices[len(req.prefix_indices) : new_prefix_len]
)

# The prefix indices could be updated, reuse it
new_indices, new_last_node = self.match_prefix(token_ids)
@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
MHATokenToKVPool,
MLATokenToKVPool,
ReqToTokenPool,
TokenToKVPoolAllocator,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

@@ -98,6 +99,8 @@ class ModelRunner:
nccl_port: int,
server_args: ServerArgs,
is_draft_worker: bool = False,
req_to_token_pool: Optional[ReqToTokenPool] = None,
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
):
# Parse args
self.model_config = model_config

@@ -115,6 +118,8 @@ class ModelRunner:
self.spec_algorithm = SpeculativeAlgorithm.from_string(
server_args.speculative_algorithm
)
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

# Model-specific adjustment
if (

@@ -257,8 +262,8 @@ class ModelRunner:

def init_torch_distributed(self):
logger.info("Init torch distributed begin.")

torch.get_device_module(self.device).set_device(self.gpu_id)

if self.device == "cuda":
backend = "nccl"
elif self.device == "xpu":

@@ -660,12 +665,25 @@ class ModelRunner:
if not self.spec_algorithm.is_none():
if self.is_draft_worker:
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
max_num_reqs = self.server_args.max_num_reqs
else:
# We are sharing the `token_to_kv_pool`, and both verify and draft tokens
# can be concurrently allocated, so we should give a headroom for it.
self.server_args.draft_runner_cache_size = (
self.max_total_num_tokens
+ max_num_reqs * self.server_args.speculative_num_steps
# draft
+ max_num_reqs
* self.server_args.speculative_num_steps
* self.server_args.speculative_eagle_topk
# verify
+ max_num_reqs * self.server_args.speculative_num_draft_tokens
# buffer
+ 100
)
# Target worker and draft worker shares the same indices for the
# token_to_kv_pool, so we should make sure to match max_total_num_tokens.
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
self.server_args.max_num_reqs = max_num_reqs

if max_total_tokens is not None:
if max_total_tokens > self.max_total_num_tokens:

@@ -681,12 +699,25 @@ class ModelRunner:
"Not enough memory. Please try to increase --mem-fraction-static."
)

self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs + 1,
max_context_len=self.model_config.context_len + 4,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
if self.req_to_token_pool is None:
self.req_to_token_pool = ReqToTokenPool(
size=max_num_reqs + 1,
max_context_len=self.model_config.context_len + 4,
device=self.device,
enable_memory_saver=self.server_args.enable_memory_saver,
)
else:
# Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker

if self.token_to_kv_pool_allocator is None:
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
self.max_total_num_tokens,
dtype=self.kv_cache_dtype,
device=self.device,
)
else:
assert self.is_draft_worker

if (
self.model_config.attention_arch == AttentionArch.MLA
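The `draft_runner_cache_size` expression above sizes the shared KV pool as the target capacity plus per-request headroom for concurrently allocated draft and verify tokens. A rough worked example with made-up numbers (these are not defaults taken from the diff):

```python
# Hypothetical settings, for illustration only.
max_total_num_tokens = 100_000
max_num_reqs = 32
speculative_num_steps = 5
speculative_eagle_topk = 4
speculative_num_draft_tokens = 8

draft_runner_cache_size = (
    max_total_num_tokens
    # draft: each request may hold topk candidates per draft step
    + max_num_reqs * speculative_num_steps * speculative_eagle_topk  # 640
    # verify: slots for the draft tokens checked by the target model
    + max_num_reqs * speculative_num_draft_tokens                    # 256
    # buffer
    + 100
)
print(draft_runner_cache_size)  # 100996
```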
@@ -280,11 +280,16 @@ class ServerArgs:
self.disable_overlap_schedule = True
self.prefill_only_one_req = True
self.disable_cuda_graph_padding = True
self.disable_radix_cache = True
self.chunked_prefill_size = -1
if self.max_running_requests is None:
self.max_running_requests = 32
logger.info(
f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
"Overlap scheduler are disabled because of using "
"eagle speculative decoding."
"Max running request set to 32 because of using eagle speculative decoding."
)
# The token generated from the verify step is counted.
# If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
assert self.speculative_num_steps < self.speculative_num_draft_tokens

# GGUF
if (
@@ -3,14 +3,8 @@

from typing import List

import torch

from sglang.srt.utils import is_cuda_available

if is_cuda_available():
from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
from sgl_kernel import (
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
)
from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient

def build_tree_kernel_efficient_preprocess(

@@ -21,7 +21,6 @@ from sglang.srt.model_executor.forward_batch_info import (
from sglang.srt.speculative.eagle_utils import EagleDraftInput

if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.speculative.eagle_worker import EAGLEWorker
@@ -1,16 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from typing import TYPE_CHECKING, List
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import triton
|
||||
import triton.language as tl
|
||||
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
create_flashinfer_kv_indices_triton,
|
||||
)
|
||||
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||
from sglang.srt.speculative.build_eagle_tree import (
|
||||
build_tree_kernel,
|
||||
@@ -25,7 +26,7 @@ if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@dataclass
|
||||
class EagleDraftInput:
|
||||
# The inputs for decode
|
||||
# shape: (b, topk)
|
||||
@@ -46,57 +47,46 @@ class EagleDraftInput:
kv_indptr: torch.Tensor = None
kv_indices: torch.Tensor = None

# indices of unfinished requests during extend-after-decode
# e.g. [0, 2, 3, 4] if only the 1st request is finished
keep_indices: List[int] = None

def prepare_for_extend(self, batch: ScheduleBatch):
req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
batch.out_cache_loc = out_cache_loc
assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
# Prefill only generate 1 token.
assert len(self.verified_id) == len(batch.seq_lens)

pt = 0
for i, req in enumerate(batch.reqs):
req.req_pool_idx = req_pool_indices[i]
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
assert seq_len - pre_len == req.extend_input_len

if pre_len > 0:
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
:pre_len
] = req.prefix_indices

batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
out_cache_loc[pt : pt + req.extend_input_len]
for i, extend_len in enumerate(batch.extend_lens):
input_ids = batch.input_ids[pt : pt + extend_len]
batch.input_ids[pt : pt + extend_len] = torch.concat(
(input_ids[1:], self.verified_id[i].reshape(1))
)

pt += req.extend_input_len

# TODO: support batching inputs
assert len(batch.extend_lens) == 1
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))

def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
accept_length_cpu = batch.spec_info.accept_length_cpu
batch.extend_lens = [x + 1 for x in accept_length_cpu]
batch.extend_num_tokens = sum(batch.extend_lens)
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
seq_lens_cpu = batch.seq_lens.tolist()
assert len(batch.req_pool_indices) == len(batch.reqs)

pt = 0
i = 0
for req in batch.reqs:
self.keep_indices = []
for idx, req in enumerate(batch.reqs):
if req.finished():
continue
self.keep_indices.append(idx)
# assert seq_len - pre_len == req.extend_input_len
input_len = batch.extend_lens[i]
seq_len = seq_lens_cpu[i]
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
seq_len - input_len : seq_len
] = batch.out_cache_loc[pt : pt + input_len]
pt += input_len
i += 1
assert pt == batch.out_cache_loc.shape[0]

self.positions = torch.empty_like(self.verified_id)
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
self.accept_length.add_(1)

create_extend_spec_info[(self.accept_length.numel(),)](
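An illustrative sketch (not part of the diff) of the input shift done in prepare_for_extend: the draft extend consumes the prompt shifted left by one token and appends the token the target model just verified, so each draft input token lines up with the target hidden state that produced it.

# Hedged sketch with made-up token ids.
import torch

prompt = torch.tensor([101, 17, 42, 9])     # hypothetical prompt ids
verified_id = torch.tensor([55])            # token sampled/verified by the target model
draft_input_ids = torch.concat((prompt[1:], verified_id))
# draft_input_ids == tensor([17, 42, 9, 55])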
@@ -117,14 +107,22 @@ class EagleDraftInput:
self,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
req_to_token: torch.Tensor,
):
bs = self.accept_length.numel()
keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
req_pool_indices = req_pool_indices[keep_indices]
assert req_pool_indices.shape[0] == bs
assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]

qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)

cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

# TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")

create_flashinfer_kv_indices_triton[(bs,)](
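A minimal sketch (toy shapes, not from the diff) of the indptr layout built above: per-request lengths become a CSR-style prefix sum, which is how FlashInfer-style prefill wrappers expect ragged batches to be described.

# Hedged sketch: request i owns rows qo_indptr[i]:qo_indptr[i+1].
import torch

accept_length = torch.tensor([2, 1, 3], dtype=torch.int32)   # hypothetical per-request lengths
qo_indptr = torch.zeros(accept_length.numel() + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(accept_length, dim=0)
# qo_indptr == tensor([0, 2, 3, 6])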
@@ -162,7 +160,21 @@ class EagleDraftInput:
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])


@dataclasses.dataclass
@dataclass
class EagleVerifyOutput:
# Draft input batch
draft_input: EagleDraftInput
# Logit outputs from target worker
logits_output: LogitsProcessorOutput
# Accepted token ids including the bonus token
verified_id: torch.Tensor
# Accepted token length per sequence in a batch in CPU.
accept_length_per_req_cpu: List[int]
# Accepted indices from logits_output.next_token_logits
accepeted_indices_cpu: List[int]


@dataclass
class EagleVerifyInput:
draft_token: torch.Tensor
custom_mask: torch.Tensor
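A hedged sketch (toy values, not taken from the diff) of why the new EagleVerifyOutput dataclass helps: callers read verify results by field name instead of unpacking a positional tuple.

# Toy stand-in for EagleVerifyOutput; the fields mirror the dataclass above.
from dataclasses import dataclass
from typing import List

@dataclass
class ToyVerifyOutput:
    verified_id: List[int]                  # accepted token ids, bonus token included
    accept_length_per_req_cpu: List[int]    # accepted draft tokens per request

out = ToyVerifyOutput(verified_id=[55, 7, 91], accept_length_per_req_cpu=[2, 0])
print(sum(out.accept_length_per_req_cpu))   # total accepted draft tokens in the batch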
@@ -267,6 +279,7 @@ class EagleVerifyInput:
self,
req_pool_indices: torch.Tensor,
paged_kernel_lens: torch.Tensor,
paged_kernel_lens_sum: int,
req_to_token: torch.Tensor,
):
batch_size = len(req_pool_indices)
@@ -285,7 +298,11 @@ class EagleVerifyInput:
paged_kernel_lens = paged_kernel_lens + self.draft_token_num
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)

kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
kv_indices = torch.empty(
paged_kernel_lens_sum + self.draft_token_num * batch_size,
dtype=torch.int32,
device="cuda",
)

create_flashinfer_kv_indices_triton[(batch_size,)](
req_to_token,
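A short sketch (toy sizes, not part of the diff) of what the new allocation buys: the buffer size is computed from CPU-side integers that are passed in, instead of reading cum_kv_seq_len[-1] back from the GPU, which would force a device sync.

# Hedged sketch of the host-only size computation.
import torch

paged_kernel_lens_sum = 10          # hypothetical sum of per-request KV lengths (a Python int)
draft_token_num, batch_size = 4, 2
# old: size = int(cum_kv_seq_len[-1])   -> GPU -> CPU sync
# new: size is already known on the host
size = paged_kernel_lens_sum + draft_token_num * batch_size
kv_indices = torch.empty(size, dtype=torch.int32)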
@@ -298,7 +315,21 @@ class EagleVerifyInput:
)
return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask

def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
def verify(
self,
batch: ScheduleBatch,
logits_output: torch.Tensor,
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
) -> torch.Tensor:
"""WARNING: This API in-place modifies the states of logits_output

Verify and find accepted tokens based on logits output and batch
(which contains spec decoding information).

This API updates values inside logits_output based on the accepted
tokens. I.e., logits_output.next_token_logits only contains
accepted token logits.
"""
draft_token = torch.cat(
[self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
dim=-1,
@@ -367,7 +398,6 @@ class EagleVerifyInput:

new_accept_index = []
unfinished_index = []
finished_extend_len = {} # {rid:accept_length + 1}
accept_index_cpu = accept_index.tolist()
predict_cpu = predict.tolist()
has_finished = False
@@ -382,7 +412,6 @@ class EagleVerifyInput:
id = predict_cpu[idx]
# if not found_finished:
req.output_ids.append(id)
finished_extend_len[req.rid] = j + 1
req.check_finished()
if req.finished():
has_finished = True
@@ -400,11 +429,10 @@ class EagleVerifyInput:
accept_index = accept_index[accept_index != -1]
accept_length_cpu = accept_length.tolist()
verified_id = predict[accept_index]

evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
evict_mask[accept_index] = False
mem_need_free_idx = batch.out_cache_loc[evict_mask]
batch.token_to_kv_pool.free(mem_need_free_idx)
token_to_kv_pool_allocator.free(mem_need_free_idx)
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
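An illustrative sketch (toy values, not part of the diff) of the evict-mask logic above: every draft position is marked for eviction, the accepted positions are cleared, and the remaining KV slots are handed back to the allocator.

# Hedged sketch of how rejected draft tokens' KV slots are collected.
import torch

num_draft = 6
accept_index = torch.tensor([0, 2])                      # hypothetical accepted positions
out_cache_loc = torch.arange(100, 100 + num_draft)       # hypothetical KV slots of the draft tokens
evict_mask = torch.full((num_draft,), True, dtype=torch.bool)
evict_mask[accept_index] = False
mem_need_free_idx = out_cache_loc[evict_mask]            # slots to return to the allocator
# mem_need_free_idx == tensor([101, 103, 104, 105])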
@@ -427,20 +455,16 @@ class EagleVerifyInput:
]
if has_finished:
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
unfinished_index
]
else:
draft_input.seq_lens_for_draft_extend = batch.seq_lens
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
batch.out_cache_loc = batch.out_cache_loc[new_accept_index]

logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
return (
draft_input,
logits_output,
verified_id,
finished_extend_len,
accept_length_cpu,
return EagleVerifyOutput(
draft_input=draft_input,
logits_output=logits_output,
verified_id=verified_id,
accept_length_per_req_cpu=accept_length_cpu,
accepeted_indices_cpu=accept_index,
)


@@ -456,6 +480,18 @@ def eagle_verify_retrive(
draft_token_num: tl.constexpr,
max_len_upper: tl.constexpr,
):
"""
Args:
retrive_index: Pointer to indices of draft tokens
accept_mask: Mask indicating which tokens were accepted
retrive_cum_len: Cumulative lengths of token sequences in a batch
accept_index (out): Accept token indices
accept_length (out): Length of accepted tokens per sequence in a batch
extract_index (out): Index for last accepted tokens
max_len: Maximum length in a batch
draft_token_num: Number of tokens speculatively generated
max_len_upper: An upper bound for token sequence length
"""
pid = tl.program_id(axis=0)

retrive_end = tl.load(retrive_cum_len + pid + 1)
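A hedged, plain-PyTorch sketch (one sequence, toy values) of what the kernel above fills in per sequence; conceptually this resembles taking the longest accepted prefix of the drafted tokens, though the real kernel walks the retrieval tree.

# Hedged sketch only; not the Triton kernel's actual traversal.
import torch

accept_mask = torch.tensor([1, 1, 0, 1, 0])        # hypothetical per-token accept flags
prefix_ok = torch.cumprod(accept_mask, dim=0)      # stops at the first rejection
accept_length = int(prefix_ok.sum())               # == 2 accepted draft tokens
accept_index = torch.arange(accept_length)         # indices of the accepted tokens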
@@ -649,7 +685,7 @@ def generate_draft_decode_kv_indices(
tl.store(kv_indptr + zid, base + zid * iters)


@torch.compile
@torch.compile(dynamic=True)
def select_top_k_tokens(
i: int,
topk_p: torch.Tensor,
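A hedged sketch of what dynamic=True does: it asks torch.compile to treat sizes symbolically, so a function that sees many batch sizes is compiled once instead of once per shape. The motivation for the change above is assumed here, not stated in the diff.

# Hedged example of shape-dynamic compilation.
import torch

@torch.compile(dynamic=True)
def scale(x: torch.Tensor) -> torch.Tensor:
    return x * 2

for bs in (1, 4, 7):          # varying batch sizes reuse one compiled graph
    scale(torch.ones(bs, 8))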
@@ -671,13 +707,11 @@ def select_top_k_tokens(
.unsqueeze(0)
.repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
)

else:
# The later decode steps
expand_scores = torch.mul(
scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)

topk_cs_p, topk_cs_index = fast_topk(
expand_scores.flatten(start_dim=1), topk, dim=-1
) # (b, topk)

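A small sketch (toy sizes, torch.topk standing in for fast_topk) of the later-step scoring above: every surviving parent expands into topk children, path scores multiply, and the global top-k over the topk*topk candidates survive to the next step.

# Hedged sketch of cumulative path scoring at a later draft step.
import torch

b, topk = 1, 2
scores = torch.tensor([[0.6, 0.3]])                       # hypothetical parent path scores (b, topk)
topk_p = torch.tensor([[[0.5, 0.2], [0.7, 0.1]]])         # child probs per parent (b, topk, topk)
expand_scores = scores.unsqueeze(2) * topk_p              # (b, topk, topk)
topk_cs_p, topk_cs_index = torch.topk(expand_scores.flatten(start_dim=1), topk, dim=-1)
# topk_cs_p == tensor([[0.30, 0.21]]); flat index // topk gives the parent, % topk the child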
@@ -1,7 +1,7 @@
import logging
import os
import time
from typing import List, Optional, Union
from typing import Dict, List, Optional, Tuple, Union

import torch
from huggingface_hub import snapshot_download
@@ -22,11 +22,13 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
from sglang.srt.speculative.eagle_utils import (
EagleDraftInput,
EagleVerifyInput,
EagleVerifyOutput,
assign_draft_cache_locs,
fast_topk,
select_top_k_tokens,
)
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.utils import get_available_gpu_memory

logger = logging.getLogger(__name__)

@@ -42,12 +44,16 @@ class EAGLEWorker(TpModelWorker):
nccl_port: int,
target_worker: TpModelWorker,
):
# Override context length with target model's context length
server_args.context_length = target_worker.model_runner.model_config.context_len
os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"

# Do not capture cuda graph in `super().__init__()`
# We will capture it later
backup_disable_cuda_graph = server_args.disable_cuda_graph
server_args.disable_cuda_graph = True

# Load hot token ids
# Lossy optimization by using hot tokens
if server_args.speculative_token_map is not None:
self.hot_token_id = load_token_map(server_args.speculative_token_map)
server_args.json_model_override_args = (
@@ -56,6 +62,12 @@ class EAGLEWorker(TpModelWorker):
else:
self.hot_token_id = None

# We share the allocator with a target worker. Draft/target worker
# owns its own KV cache.
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
target_worker.get_memory_pool()
)

# Init target worker
super().__init__(
gpu_id=gpu_id,
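A hedged sketch (simplified stand-ins, not the real pool classes) of the sharing arrangement above: the draft worker reuses the target worker's req_to_token pool and KV-slot allocator, so both models agree on which slots a request owns, while each model keeps its own KV tensors behind those slots.

# Toy allocator shared by two workers; sglang's real allocator is more involved.
class ToyAllocator:
    def __init__(self, size: int):
        self.free_slots = list(range(size))

    def alloc(self, n: int):
        slots, self.free_slots = self.free_slots[:n], self.free_slots[n:]
        return slots

    def free(self, slots):
        self.free_slots.extend(slots)

shared_allocator = ToyAllocator(size=8)
target_slots = shared_allocator.alloc(3)   # target writes its KV cache at these slots
draft_slots = target_slots                 # draft reuses the same slot ids in its own cache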
@@ -64,9 +76,10 @@ class EAGLEWorker(TpModelWorker):
nccl_port=nccl_port,
dp_rank=dp_rank,
is_draft_worker=True,
req_to_token_pool=self.req_to_token_pool,
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
)
self.target_worker = target_worker
self.finish_extend_len = []

# Parse arguments
self.topk = server_args.speculative_eagle_topk
@@ -75,6 +88,9 @@ class EAGLEWorker(TpModelWorker):
server_args.speculative_algorithm
)
self.server_args = server_args
self.use_nan_detection = self.server_args.enable_nan_detection
self.device = self.model_runner.device
self.gpu_id = self.model_runner.gpu_id

# Share the embedding and lm_head
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -82,8 +98,10 @@ class EAGLEWorker(TpModelWorker):
head = head.clone()
self.hot_token_id = self.hot_token_id.to(head.device)
head.data = head.data[self.hot_token_id]
self.model_runner.model.set_embed_and_head(embed, head)
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
self.draft_model_runner.model.set_embed_and_head(embed, head)
self.draft_model_runner.server_args.disable_cuda_graph = (
backup_disable_cuda_graph
)

# Create multi-step attn backends and cuda graph runners
if server_args.attention_backend == "flashinfer":
@@ -111,7 +129,7 @@ class EAGLEWorker(TpModelWorker):
f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
)

self.model_runner.draft_attn_backend = self.draft_attn_backend
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
self.init_cuda_graphs()

def init_cuda_graphs(self):
@@ -122,55 +140,81 @@ class EAGLEWorker(TpModelWorker):
return

tic = time.time()
logger.info("Capture cuda graph begin. This can take up to several minutes.")
logger.info(
f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
logger.info(
f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
)

def forward_batch_speculative_generation(self, batch: ScheduleBatch):
@property
def draft_model_runner(self):
return self.model_runner

def forward_batch_speculative_generation(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
"""Run speculative decoding forward.

NOTE: Many states of the batch are modified as you go through. It is not guaranteed
the final output batch doesn't have the same state as the input.

Args:
batch: The batch to run forward. The state of the batch is modified as it runs.
Returns:
A tuple of the final logit output of the target model, next tokens accepted,
the batch id (used for overlap schedule), and number of accepted tokens.
"""
assert not batch.spec_algorithm.is_none()
if batch.forward_mode.is_decode():
# Draft
spec_info: EagleVerifyInput = self.draft(batch)

# Verify
(
next_draft_input,
logits_output,
verified_id,
self.finish_extend_len,
accept_length_cpu,
model_worker_batch,
) = self.verify(batch, spec_info)
batch.spec_info = next_draft_input
# if it is None, means all requsets are finished
spec_info, to_free_cache_loc = self.draft(batch)
logits_output, verify_output, model_worker_batch = self.verify(
batch, spec_info
)
# Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
self.token_to_kv_pool_allocator.free(to_free_cache_loc)
# if it is None, means all requests are finished
if batch.spec_info.verified_id is not None:
self.forward_draft_extend_after_decode(batch)

return (
logits_output,
verified_id,
model_worker_batch,
sum(accept_length_cpu),
verify_output.verified_id,
model_worker_batch.bid,
sum(verify_output.accept_length_per_req_cpu),
)

else:
# Forward with the target model and get hidden states.
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
model_worker_batch
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
self.forward_draft_extend(
batch, logits_output.hidden_states, next_token_ids
)
return logits_output, next_token_ids, bid, 0

# Forward with the draft model.
batch.spec_info = EagleDraftInput(
hidden_states=logits_output.hidden_states,
verified_id=next_token_ids,
)
self.forward_draft_extend(batch)
return logits_output, next_token_ids, model_worker_batch, 0
def forward_target_extend(
self, batch: ScheduleBatch
) -> Tuple[LogitsProcessorOutput, List[int], int]:
"""Run the target extend.

Args:
batch: The batch to run. States could be modified.

Returns:
logits_output: The output of logits. It will contain the full hidden states.
next_token_ids: Next token ids generated.
bid: The model batch ID. Used for overlap schedule.
"""
# Forward with the target model and get hidden states.
# We need the full hidden states to prefill the KV cache of the draft model.
model_worker_batch = batch.get_model_worker_batch()
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
model_worker_batch
)
return logits_output, next_token_ids, model_worker_batch.bid

def draft(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)

# Parse args
num_seqs = batch.batch_size()
spec_info = batch.spec_info
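A hedged control-flow sketch (stubbed callables, not the real worker) of one decode round in the refactored flow: draft, then verify, then free the rejected KV slots, then run the draft extend for requests that are still running.

# Sketch only; parameter names such as free_slots and draft_extend_after_decode are
# stand-ins for the worker methods shown in the diff.
def speculative_decode_round(batch, draft, verify, free_slots, draft_extend_after_decode):
    spec_info, to_free_cache_loc = draft(batch)            # propose a token tree
    logits_output, verify_output, model_worker_batch = verify(batch, spec_info)
    free_slots(to_free_cache_loc)                          # rejected draft tokens give back their KV slots
    if verify_output.draft_input.verified_id is not None:  # some requests still unfinished
        draft_extend_after_decode(batch)                   # refresh draft hidden states for the next round
    return (
        logits_output,
        verify_output.verified_id,
        model_worker_batch.bid,
        sum(verify_output.accept_length_per_req_cpu),
    )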
@@ -188,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
)

batch.out_cache_loc = out_cache_loc
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
@@ -196,11 +239,12 @@ class EAGLEWorker(TpModelWorker):
# Get forward batch
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
forward_batch
)

if can_cuda_graph:
score_list, token_list, parents_list = self.cuda_graph_runner.replay(
forward_batch
@@ -208,7 +252,9 @@ class EAGLEWorker(TpModelWorker):
else:
# Initialize attention backend
self.draft_attn_backend.init_forward_metadata(forward_batch)

forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
# Run forward steps
score_list, token_list, parents_list = self.draft_forward(forward_batch)

@@ -225,10 +271,7 @@ class EAGLEWorker(TpModelWorker):
batch.sampling_info.is_all_greedy,
)

# Free cache locations
batch.token_to_kv_pool.free(out_cache_loc)
self._set_mem_pool(batch, self.target_worker.model_runner)
return ret
return ret, out_cache_loc

def draft_forward(self, forward_batch: ForwardBatch):
# Parse args
@@ -278,6 +321,7 @@ class EAGLEWorker(TpModelWorker):
logits_output = self.model_runner.model.forward(
forward_batch.input_ids, forward_batch.positions, forward_batch
)
self._detect_nan_if_needed(logits_output)
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
if self.hot_token_id is not None:
@@ -294,71 +338,88 @@ class EAGLEWorker(TpModelWorker):
logits_output, _ = self.target_worker.forward_batch_generation(
model_worker_batch, skip_sample=True
)
self._detect_nan_if_needed(logits_output)
spec_info.hidden_states = logits_output.hidden_states
res = spec_info.verify(batch, logits_output)
batch.forward_mode = ForwardMode.DECODE
return res + (model_worker_batch,)
res: EagleVerifyOutput = spec_info.verify(
batch, logits_output, self.token_to_kv_pool_allocator
)

def forward_draft_extend(self, batch: ScheduleBatch):
self._set_mem_pool(batch, self.model_runner)
# Post process based on verified outputs.
# Pick indices that we care (accepted)
logits_output.next_token_logits = logits_output.next_token_logits[
res.accepeted_indices_cpu
]
logits_output.hidden_states = logits_output.hidden_states[
res.accepeted_indices_cpu
]
# Prepare the batch for the next draft forwards.
batch.forward_mode = ForwardMode.DECODE
batch.spec_info = res.draft_input

return logits_output, res, model_worker_batch

def forward_draft_extend(
self,
batch: ScheduleBatch,
hidden_states: torch.Tensor,
next_token_ids: List[int],
):
"""Run draft model extend. This API modifies the states of the batch.

Args:
batch: The batch to run.
hidden_states: Hidden states from the target model forward
next_token_ids: Next token ids generated from the target forward.
"""
batch.spec_info = EagleDraftInput(
hidden_states=hidden_states,
verified_id=next_token_ids,
)
batch.spec_info.prepare_for_extend(batch)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._set_mem_pool(batch, self.target_worker.model_runner)

def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
batch.token_to_kv_pool = runner.token_to_kv_pool
batch.req_to_token_pool = runner.req_to_token_pool
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
logits_output = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output)
assert isinstance(forward_batch.spec_info, EagleDraftInput)
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)

def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
seq_lens_backup = batch.seq_lens
req_pool_indices_backup = batch.req_pool_indices

self._set_mem_pool(batch, self.model_runner)
batch.forward_mode = ForwardMode.DRAFT_EXTEND
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
# We don't need logprob for this extend.
model_worker_batch = batch.get_model_worker_batch()
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
logits_output = self.model_runner.forward(forward_batch)
self.capture_for_decode(logits_output, forward_batch)
self._set_mem_pool(batch, self.target_worker.model_runner)
forward_batch = ForwardBatch.init_new(
model_worker_batch, self.draft_model_runner
)
logits_output = self.draft_model_runner.forward(forward_batch)
self._detect_nan_if_needed(logits_output)
assert forward_batch.spec_info is batch.spec_info
self.capture_for_decode(logits_output, forward_batch.spec_info)

# Restore backup.
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
batch.forward_mode = ForwardMode.DECODE
batch.seq_lens = seq_lens_backup
batch.req_pool_indices = req_pool_indices_backup

def capture_for_decode(
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
):
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
spec_info = forward_batch.spec_info
spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
spec_info.hidden_states = logits_output.hidden_states
draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
draft_input.hidden_states = logits_output.hidden_states

# Don't support prefix share now.
def finish_request(self, reqs: Union[Req, List[Req]]):
if not isinstance(reqs, List):
reqs = [reqs]
for req in reqs:
if req.rid not in self.finish_extend_len:
continue
req_len = (
len(req.origin_input_ids)
+ len(req.output_ids)
- self.finish_extend_len[req.rid]
- 1
)
kv_indices = self.model_runner.req_to_token_pool.req_to_token[
req.req_pool_idx
][:req_len]
self.model_runner.token_to_kv_pool.free(kv_indices)
self.model_runner.req_to_token_pool.free(req.req_pool_idx)
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
if self.use_nan_detection:
logits = logits_output.next_token_logits
if torch.any(torch.isnan(logits)):
logger.warning("Detected errors during sampling! NaN in the logits.")
raise ValueError("Detected errors during sampling! NaN in the logits.")


def load_token_map(token_map_path: str) -> List[int]:

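A brief sketch (toy shapes, torch.topk standing in for fast_topk) of what capture_for_decode stores: the draft model's top-k next-token probabilities and indices are written onto the EagleDraftInput so the next draft round can start from them.

# Hedged sketch of the top-k capture.
import torch

next_token_logits = torch.randn(2, 32000)               # hypothetical (batch, vocab) logits
probs = torch.softmax(next_token_logits, dim=-1)
topk_p, topk_index = torch.topk(probs, k=4, dim=-1)     # shapes: (2, 4) each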
@@ -20,7 +20,3 @@ class SpeculativeAlgorithm(IntEnum):
if name is not None:
name = name.upper()
return name_map[name]


class SpecInfo:
pass
