[Eagle] Refactor eagle speculative decoding (#3986)
Co-authored-by: Ke Bao <ISPObaoke@163.com>
This commit is contained in:
@@ -230,7 +230,7 @@ def extend(reqs, model_runner):
|
|||||||
batch = ScheduleBatch.init_new(
|
batch = ScheduleBatch.init_new(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
req_to_token_pool=model_runner.req_to_token_pool,
|
req_to_token_pool=model_runner.req_to_token_pool,
|
||||||
token_to_kv_pool=model_runner.token_to_kv_pool,
|
token_to_kv_pool_allocator=model_runner.token_to_kv_pool_allocator,
|
||||||
tree_cache=None,
|
tree_cache=None,
|
||||||
model_config=model_runner.model_config,
|
model_config=model_runner.model_config,
|
||||||
enable_overlap=False,
|
enable_overlap=False,
|
||||||
@@ -326,7 +326,7 @@ def latency_test_run_once(
|
|||||||
|
|
||||||
# Clear the pools.
|
# Clear the pools.
|
||||||
model_runner.req_to_token_pool.clear()
|
model_runner.req_to_token_pool.clear()
|
||||||
model_runner.token_to_kv_pool.clear()
|
model_runner.token_to_kv_pool_allocator.clear()
|
||||||
|
|
||||||
measurement_results = {
|
measurement_results = {
|
||||||
"run_name": run_name,
|
"run_name": run_name,
|
||||||
|
|||||||
@@ -20,14 +20,15 @@ import triton.language as tl
|
|||||||
|
|
||||||
from sglang.global_config import global_config
|
from sglang.global_config import global_config
|
||||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||||
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||||
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
from sglang.srt.layers.dp_attention import get_attention_tp_size
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.utils import is_flashinfer_available
|
from sglang.srt.utils import is_flashinfer_available
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.speculative.spec_info import SpecInfo
|
|
||||||
|
|
||||||
if is_flashinfer_available():
|
if is_flashinfer_available():
|
||||||
from flashinfer import (
|
from flashinfer import (
|
||||||
@@ -36,6 +37,7 @@ if is_flashinfer_available():
|
|||||||
BatchPrefillWithRaggedKVCacheWrapper,
|
BatchPrefillWithRaggedKVCacheWrapper,
|
||||||
)
|
)
|
||||||
from flashinfer.cascade import merge_state
|
from flashinfer.cascade import merge_state
|
||||||
|
from flashinfer.decode import PosEncodingMode
|
||||||
|
|
||||||
|
|
||||||
class WrapperDispatch(Enum):
|
class WrapperDispatch(Enum):
|
||||||
@@ -113,6 +115,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
device=model_runner.device,
|
device=model_runner.device,
|
||||||
)
|
)
|
||||||
self.workspace_buffer = global_workspace_buffer
|
self.workspace_buffer = global_workspace_buffer
|
||||||
|
|
||||||
max_bs = model_runner.req_to_token_pool.size
|
max_bs = model_runner.req_to_token_pool.size
|
||||||
if kv_indptr_buf is None:
|
if kv_indptr_buf is None:
|
||||||
self.kv_indptr = [
|
self.kv_indptr = [
|
||||||
@@ -133,10 +136,13 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
assert self.num_wrappers == 1
|
assert self.num_wrappers == 1
|
||||||
self.kv_last_page_len = kv_last_page_len_buf
|
self.kv_last_page_len = kv_last_page_len_buf
|
||||||
|
|
||||||
self.qo_indptr = [
|
if not self.skip_prefill:
|
||||||
torch.zeros((max_bs + 1,), dtype=torch.int32, device=model_runner.device)
|
self.qo_indptr = [
|
||||||
for _ in range(self.num_wrappers)
|
torch.zeros(
|
||||||
]
|
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
|
||||||
|
)
|
||||||
|
for _ in range(self.num_wrappers)
|
||||||
|
]
|
||||||
|
|
||||||
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
self.prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
|
||||||
self.workspace_buffer, "NHD"
|
self.workspace_buffer, "NHD"
|
||||||
@@ -276,7 +282,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
seq_lens: torch.Tensor,
|
seq_lens: torch.Tensor,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
decode_wrappers = []
|
decode_wrappers = []
|
||||||
@@ -346,7 +352,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
forward_mode: ForwardMode,
|
forward_mode: ForwardMode,
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
if forward_mode.is_decode_or_idle():
|
if forward_mode.is_decode_or_idle():
|
||||||
self.indices_updater_decode.update(
|
self.indices_updater_decode.update(
|
||||||
@@ -526,7 +532,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
# Keep the signature for type checking. It will be assigned during runtime.
|
# Keep the signature for type checking. It will be assigned during runtime.
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -538,7 +544,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||||
self.call_begin_forward(
|
self.call_begin_forward(
|
||||||
@@ -558,7 +564,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -592,7 +598,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
seq_lens_sum: int,
|
seq_lens_sum: int,
|
||||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -623,7 +629,7 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
paged_kernel_lens_sum: int,
|
paged_kernel_lens_sum: int,
|
||||||
kv_indptr: torch.Tensor,
|
kv_indptr: torch.Tensor,
|
||||||
kv_start_idx: torch.Tensor,
|
kv_start_idx: torch.Tensor,
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
bs = len(req_pool_indices)
|
bs = len(req_pool_indices)
|
||||||
@@ -642,9 +648,9 @@ class FlashInferIndicesUpdaterDecode:
|
|||||||
self.req_to_token.shape[1],
|
self.req_to_token.shape[1],
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
assert isinstance(spec_info, EagleDraftInput)
|
||||||
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
|
||||||
bs = kv_indptr.shape[0] - 1
|
bs = kv_indptr.shape[0] - 1
|
||||||
|
|
||||||
wrapper.begin_forward(
|
wrapper.begin_forward(
|
||||||
kv_indptr,
|
kv_indptr,
|
||||||
kv_indices,
|
kv_indices,
|
||||||
@@ -699,7 +705,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
# Keep the signature for type checking. It will be assigned during runtime.
|
# Keep the signature for type checking. It will be assigned during runtime.
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
@@ -713,7 +719,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
if use_ragged:
|
if use_ragged:
|
||||||
paged_kernel_lens = prefix_lens
|
paged_kernel_lens = prefix_lens
|
||||||
@@ -746,7 +752,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -787,7 +793,7 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
encoder_lens: Optional[torch.Tensor],
|
encoder_lens: Optional[torch.Tensor],
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
for wrapper_id in range(2):
|
for wrapper_id in range(2):
|
||||||
if wrapper_id == 0:
|
if wrapper_id == 0:
|
||||||
@@ -829,10 +835,11 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
kv_indptr: torch.Tensor,
|
kv_indptr: torch.Tensor,
|
||||||
qo_indptr: torch.Tensor,
|
qo_indptr: torch.Tensor,
|
||||||
use_ragged: bool,
|
use_ragged: bool,
|
||||||
spec_info: Optional[SpecInfo],
|
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||||
):
|
):
|
||||||
bs = len(req_pool_indices)
|
bs = len(seq_lens)
|
||||||
if spec_info is None:
|
if spec_info is None:
|
||||||
|
assert len(seq_lens) == len(req_pool_indices)
|
||||||
# Normal extend
|
# Normal extend
|
||||||
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
kv_indptr = kv_indptr[: bs + 1]
|
kv_indptr = kv_indptr[: bs + 1]
|
||||||
@@ -855,10 +862,14 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
qo_indptr = qo_indptr[: bs + 1]
|
qo_indptr = qo_indptr[: bs + 1]
|
||||||
custom_mask = None
|
custom_mask = None
|
||||||
else:
|
else:
|
||||||
|
assert isinstance(spec_info, EagleDraftInput) or isinstance(
|
||||||
|
spec_info, EagleVerifyInput
|
||||||
|
)
|
||||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||||
spec_info.generate_attn_arg_prefill(
|
spec_info.generate_attn_arg_prefill(
|
||||||
req_pool_indices,
|
req_pool_indices,
|
||||||
paged_kernel_lens,
|
paged_kernel_lens,
|
||||||
|
paged_kernel_lens_sum,
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -890,6 +901,11 @@ class FlashInferIndicesUpdaterPrefill:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Use as a fast path to override the indptr in flashinfer's plan function
|
||||||
|
# This is used to remove some host-to-device copy overhead.
|
||||||
|
# Initialise explicitly: a bare `global` statement at module scope binds
# nothing, so the name would not exist until the first assignment elsewhere.
# None means "no override", the same value common_template resets it to after
# its draft-step loop.
global_override_indptr_cpu = None
|
||||||
|
|
||||||
|
|
||||||
class FlashInferMultiStepDraftBackend:
|
class FlashInferMultiStepDraftBackend:
|
||||||
"""
|
"""
|
||||||
Wrap multiple flashinfer attention backends as one for multiple consecutive
|
Wrap multiple flashinfer attention backends as one for multiple consecutive
|
||||||
@@ -907,6 +923,7 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
self.topk = topk
|
self.topk = topk
|
||||||
self.speculative_num_steps = speculative_num_steps
|
self.speculative_num_steps = speculative_num_steps
|
||||||
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
|
self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
|
||||||
|
|
||||||
max_bs = model_runner.req_to_token_pool.size * self.topk
|
max_bs = model_runner.req_to_token_pool.size * self.topk
|
||||||
self.kv_indptr = torch.zeros(
|
self.kv_indptr = torch.zeros(
|
||||||
(
|
(
|
||||||
@@ -929,7 +946,9 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
kv_last_page_len_buf=self.kv_last_page_len,
|
kv_last_page_len_buf=self.kv_last_page_len,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
self.max_context_len = self.attn_backends[0].max_context_len
|
self.max_context_len = self.attn_backends[0].max_context_len
|
||||||
|
|
||||||
# Cached variables for generate_draft_decode_kv_indices
|
# Cached variables for generate_draft_decode_kv_indices
|
||||||
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
|
||||||
|
|
||||||
@@ -959,13 +978,23 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
triton.next_power_of_2(bs),
|
triton.next_power_of_2(bs),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
assert forward_batch.spec_info is not None
|
||||||
|
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||||
|
|
||||||
|
# Copy the kv_indptr once to avoid multiple device-to-host copies in flashinfer's plan.
|
||||||
|
indptr_cpu_whole = self.kv_indptr[:, : bs + 1].cpu()
|
||||||
|
global global_override_indptr_cpu
|
||||||
|
|
||||||
for i in range(self.speculative_num_steps - 1):
|
for i in range(self.speculative_num_steps - 1):
|
||||||
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
|
||||||
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
|
forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
|
||||||
: seq_lens_sum * self.topk + bs * (i + 1)
|
: seq_lens_sum * self.topk + bs * (i + 1)
|
||||||
]
|
]
|
||||||
|
global_override_indptr_cpu = indptr_cpu_whole[i]
|
||||||
call_fn(i, forward_batch)
|
call_fn(i, forward_batch)
|
||||||
|
|
||||||
|
global_override_indptr_cpu = None
|
||||||
|
|
||||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
def init_forward_metadata(self, forward_batch: ForwardBatch):
|
||||||
kv_indices = torch.zeros(
|
kv_indices = torch.zeros(
|
||||||
(
|
(
|
||||||
@@ -977,6 +1006,8 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
)
|
)
|
||||||
|
|
||||||
def call_fn(i, forward_batch):
|
def call_fn(i, forward_batch):
|
||||||
|
assert forward_batch.spec_info is not None
|
||||||
|
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||||
forward_batch.spec_info.kv_indptr = (
|
forward_batch.spec_info.kv_indptr = (
|
||||||
forward_batch.spec_info.kv_indptr.clone()
|
forward_batch.spec_info.kv_indptr.clone()
|
||||||
)
|
)
|
||||||
@@ -993,6 +1024,7 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
dtype=torch.int32,
|
dtype=torch.int32,
|
||||||
device="cuda",
|
device="cuda",
|
||||||
)
|
)
|
||||||
|
|
||||||
for i in range(self.speculative_num_steps):
|
for i in range(self.speculative_num_steps):
|
||||||
self.attn_backends[i].init_cuda_graph_state(
|
self.attn_backends[i].init_cuda_graph_state(
|
||||||
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||||
@@ -1031,43 +1063,6 @@ class FlashInferMultiStepDraftBackend:
|
|||||||
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
|
||||||
|
|
||||||
|
|
||||||
@triton.jit
|
|
||||||
def create_flashinfer_kv_indices_triton(
|
|
||||||
req_to_token_ptr, # [max_batch, max_context_len]
|
|
||||||
req_pool_indices_ptr,
|
|
||||||
page_kernel_lens_ptr,
|
|
||||||
kv_indptr,
|
|
||||||
kv_start_idx,
|
|
||||||
kv_indices_ptr,
|
|
||||||
req_to_token_ptr_stride: tl.constexpr,
|
|
||||||
):
|
|
||||||
BLOCK_SIZE: tl.constexpr = 512
|
|
||||||
pid = tl.program_id(axis=0)
|
|
||||||
|
|
||||||
req_pool_index = tl.load(req_pool_indices_ptr + pid)
|
|
||||||
kv_indices_offset = tl.load(kv_indptr + pid)
|
|
||||||
|
|
||||||
kv_start = 0
|
|
||||||
kv_end = 0
|
|
||||||
if kv_start_idx:
|
|
||||||
kv_start = tl.load(kv_start_idx + pid).to(tl.int32)
|
|
||||||
kv_end = kv_start
|
|
||||||
kv_end += tl.load(page_kernel_lens_ptr + pid).to(tl.int32)
|
|
||||||
|
|
||||||
num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
|
|
||||||
for i in range(num_loop):
|
|
||||||
offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
|
|
||||||
mask = offset < kv_end - kv_start
|
|
||||||
data = tl.load(
|
|
||||||
req_to_token_ptr
|
|
||||||
+ req_pool_index * req_to_token_ptr_stride
|
|
||||||
+ kv_start
|
|
||||||
+ offset,
|
|
||||||
mask=mask,
|
|
||||||
)
|
|
||||||
tl.store(kv_indices_ptr + kv_indices_offset + offset, data, mask=mask)
|
|
||||||
|
|
||||||
|
|
||||||
def should_use_tensor_core(
|
def should_use_tensor_core(
|
||||||
kv_cache_dtype: torch.dtype,
|
kv_cache_dtype: torch.dtype,
|
||||||
num_attention_heads: int,
|
num_attention_heads: int,
|
||||||
@@ -1089,6 +1084,21 @@ def should_use_tensor_core(
|
|||||||
if env_override is not None:
|
if env_override is not None:
|
||||||
return env_override.lower() == "true"
|
return env_override.lower() == "true"
|
||||||
|
|
||||||
|
# Try to use _grouped_size_compiled_for_decode_kernels if available
|
||||||
|
# This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
|
||||||
|
try:
|
||||||
|
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
|
||||||
|
|
||||||
|
if not _grouped_size_compiled_for_decode_kernels(
|
||||||
|
num_attention_heads,
|
||||||
|
num_kv_heads,
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
else:
|
||||||
|
return False
|
||||||
|
except (ImportError, AttributeError):
|
||||||
|
pass
|
||||||
|
|
||||||
# Calculate GQA group size
|
# Calculate GQA group size
|
||||||
gqa_group_size = num_attention_heads // num_kv_heads
|
gqa_group_size = num_attention_heads // num_kv_heads
|
||||||
|
|
||||||
@@ -1118,12 +1128,18 @@ def fast_decode_plan(
|
|||||||
sm_scale: Optional[float] = None,
|
sm_scale: Optional[float] = None,
|
||||||
rope_scale: Optional[float] = None,
|
rope_scale: Optional[float] = None,
|
||||||
rope_theta: Optional[float] = None,
|
rope_theta: Optional[float] = None,
|
||||||
**kwargs,
|
non_blocking: bool = True,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend."""
|
"""
|
||||||
|
A faster version of BatchDecodeWithPagedKVCacheWrapper::plan used for FlashInferMultiStepDraftBackend.
|
||||||
|
Modifications:
|
||||||
|
- Remove unnecessary device-to-device copy for the cuda graph buffers.
|
||||||
|
- Remove unnecessary host-to-device copy for the metadata buffers.
|
||||||
|
"""
|
||||||
batch_size = len(last_page_len)
|
batch_size = len(last_page_len)
|
||||||
if logits_soft_cap is None:
|
if logits_soft_cap is None:
|
||||||
logits_soft_cap = 0.0
|
logits_soft_cap = 0.0
|
||||||
|
|
||||||
if self.is_cuda_graph_enabled:
|
if self.is_cuda_graph_enabled:
|
||||||
if batch_size != self._fixed_batch_size:
|
if batch_size != self._fixed_batch_size:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@@ -1136,13 +1152,19 @@ def fast_decode_plan(
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The size of indices should be less than or equal to the allocated buffer"
|
"The size of indices should be less than or equal to the allocated buffer"
|
||||||
)
|
)
|
||||||
|
# Skip these copies
|
||||||
|
# self._paged_kv_indptr_buf.copy_(indptr)
|
||||||
|
# self._paged_kv_indices_buf[: len(indices)] = indices
|
||||||
|
# self._paged_kv_last_page_len_buf.copy_(last_page_len)
|
||||||
else:
|
else:
|
||||||
self._paged_kv_indptr_buf = indptr
|
self._paged_kv_indptr_buf = indptr
|
||||||
self._paged_kv_indices_buf = indices
|
self._paged_kv_indices_buf = indices
|
||||||
self._paged_kv_last_page_len_buf = last_page_len
|
self._paged_kv_last_page_len_buf = last_page_len
|
||||||
|
|
||||||
# NOTE(Zihao): the following tensors act as placeholders to pass dtype info
|
# NOTE(Zihao): the following tensors act as placeholders to pass dtype info
|
||||||
if not q_data_type:
|
if not q_data_type:
|
||||||
q_data_type = data_type
|
q_data_type = data_type
|
||||||
|
|
||||||
if not hasattr(self, "empty_q_data"):
|
if not hasattr(self, "empty_q_data"):
|
||||||
self.empty_q_data = torch.empty(
|
self.empty_q_data = torch.empty(
|
||||||
0,
|
0,
|
||||||
@@ -1159,6 +1181,7 @@ def fast_decode_plan(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
self.last_page_len = torch.ones(32768, dtype=torch.int32)
|
self.last_page_len = torch.ones(32768, dtype=torch.int32)
|
||||||
|
|
||||||
empty_q_data = self.empty_q_data
|
empty_q_data = self.empty_q_data
|
||||||
empty_kv_cache = self.empty_kv_cache
|
empty_kv_cache = self.empty_kv_cache
|
||||||
stream = torch.cuda.current_stream()
|
stream = torch.cuda.current_stream()
|
||||||
|
|||||||
@@ -156,6 +156,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
spec_info.generate_attn_arg_prefill(
|
spec_info.generate_attn_arg_prefill(
|
||||||
forward_batch.req_pool_indices,
|
forward_batch.req_pool_indices,
|
||||||
forward_batch.seq_lens,
|
forward_batch.seq_lens,
|
||||||
|
None,
|
||||||
self.req_to_token,
|
self.req_to_token,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -22,7 +22,7 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPoolHost
|
from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MHATokenToKVPoolHost
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -128,7 +128,7 @@ class HiCacheController:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
mem_pool_device: MHATokenToKVPool,
|
mem_pool_device: MHATokenToKVPool,
|
||||||
mem_pool_host: MLATokenToKVPoolHost,
|
mem_pool_host: MHATokenToKVPoolHost,
|
||||||
write_policy: str = "write_through_selective",
|
write_policy: str = "write_through_selective",
|
||||||
):
|
):
|
||||||
|
|
||||||
|
|||||||
@@ -44,18 +44,16 @@ from sglang.srt.configs.model_config import ModelConfig
|
|||||||
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
from sglang.srt.mem_cache.chunk_cache import ChunkCache
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
|
||||||
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
|
||||||
from sglang.srt.sampling.sampling_params import SamplingParams
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.server_args import ServerArgs
|
|
||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
|
|
||||||
|
|
||||||
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
|
||||||
|
|
||||||
# Put some global args for easy access
|
# Put some global args for easy access
|
||||||
@@ -523,7 +521,7 @@ class ScheduleBatch:
|
|||||||
# Request, memory pool, and cache
|
# Request, memory pool, and cache
|
||||||
reqs: List[Req]
|
reqs: List[Req]
|
||||||
req_to_token_pool: ReqToTokenPool = None
|
req_to_token_pool: ReqToTokenPool = None
|
||||||
token_to_kv_pool: BaseTokenToKVPool = None
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
|
||||||
tree_cache: BasePrefixCache = None
|
tree_cache: BasePrefixCache = None
|
||||||
|
|
||||||
# Batch configs
|
# Batch configs
|
||||||
@@ -596,7 +594,7 @@ class ScheduleBatch:
|
|||||||
cls,
|
cls,
|
||||||
reqs: List[Req],
|
reqs: List[Req],
|
||||||
req_to_token_pool: ReqToTokenPool,
|
req_to_token_pool: ReqToTokenPool,
|
||||||
token_to_kv_pool: ReqToTokenPool,
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
tree_cache: BasePrefixCache,
|
tree_cache: BasePrefixCache,
|
||||||
model_config: ModelConfig,
|
model_config: ModelConfig,
|
||||||
enable_overlap: bool,
|
enable_overlap: bool,
|
||||||
@@ -606,7 +604,7 @@ class ScheduleBatch:
|
|||||||
return cls(
|
return cls(
|
||||||
reqs=reqs,
|
reqs=reqs,
|
||||||
req_to_token_pool=req_to_token_pool,
|
req_to_token_pool=req_to_token_pool,
|
||||||
token_to_kv_pool=token_to_kv_pool,
|
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
|
||||||
tree_cache=tree_cache,
|
tree_cache=tree_cache,
|
||||||
model_config=model_config,
|
model_config=model_config,
|
||||||
enable_overlap=enable_overlap,
|
enable_overlap=enable_overlap,
|
||||||
@@ -637,19 +635,19 @@ class ScheduleBatch:
|
|||||||
return req_pool_indices
|
return req_pool_indices
|
||||||
|
|
||||||
def alloc_token_slots(self, num_tokens: int):
|
def alloc_token_slots(self, num_tokens: int):
|
||||||
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
|
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||||
|
|
||||||
if out_cache_loc is None:
|
if out_cache_loc is None:
|
||||||
if self.tree_cache is not None:
|
if self.tree_cache is not None:
|
||||||
self.tree_cache.evict(num_tokens, self.token_to_kv_pool.free)
|
self.tree_cache.evict(num_tokens, self.token_to_kv_pool_allocator.free)
|
||||||
out_cache_loc = self.token_to_kv_pool.alloc(num_tokens)
|
out_cache_loc = self.token_to_kv_pool_allocator.alloc(num_tokens)
|
||||||
|
|
||||||
if out_cache_loc is None:
|
if out_cache_loc is None:
|
||||||
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
|
phase_str = "Prefill" if self.forward_mode.is_extend() else "Decode"
|
||||||
logger.error(
|
logger.error(
|
||||||
f"{phase_str} out of memory. Try to lower your batch size.\n"
|
f"{phase_str} out of memory. Try to lower your batch size.\n"
|
||||||
f"Try to allocate {num_tokens} tokens.\n"
|
f"Try to allocate {num_tokens} tokens.\n"
|
||||||
f"Avaliable tokens: {self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()}\n"
|
f"Avaliable tokens: {self.token_to_kv_pool_allocator.available_size() + self.tree_cache.evictable_size()}\n"
|
||||||
)
|
)
|
||||||
if self.tree_cache is not None:
|
if self.tree_cache is not None:
|
||||||
self.tree_cache.pretty_print()
|
self.tree_cache.pretty_print()
|
||||||
@@ -917,12 +915,12 @@ class ScheduleBatch:
|
|||||||
|
|
||||||
def check_decode_mem(self, buf_multiplier=1):
|
def check_decode_mem(self, buf_multiplier=1):
|
||||||
bs = len(self.reqs) * buf_multiplier
|
bs = len(self.reqs) * buf_multiplier
|
||||||
if self.token_to_kv_pool.available_size() >= bs:
|
if self.token_to_kv_pool_allocator.available_size() >= bs:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
self.tree_cache.evict(bs, self.token_to_kv_pool.free)
|
self.tree_cache.evict(bs, self.token_to_kv_pool_allocator.free)
|
||||||
|
|
||||||
if self.token_to_kv_pool.available_size() >= bs:
|
if self.token_to_kv_pool_allocator.available_size() >= bs:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
return False
|
return False
|
||||||
@@ -945,6 +943,10 @@ class ScheduleBatch:
|
|||||||
reverse=True,
|
reverse=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
retracted_reqs = []
|
||||||
|
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
||||||
|
first_iter = True
|
||||||
|
|
||||||
def get_required_tokens(num_reqs: int):
|
def get_required_tokens(num_reqs: int):
|
||||||
headroom_for_spec_decode = 0
|
headroom_for_spec_decode = 0
|
||||||
if server_args.speculative_algorithm:
|
if server_args.speculative_algorithm:
|
||||||
@@ -958,18 +960,15 @@ class ScheduleBatch:
|
|||||||
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
num_reqs * global_config.retract_decode_steps + headroom_for_spec_decode
|
||||||
)
|
)
|
||||||
|
|
||||||
retracted_reqs = []
|
|
||||||
seq_lens_cpu = self.seq_lens.cpu().numpy()
|
|
||||||
first_iter = True
|
|
||||||
while (
|
while (
|
||||||
self.token_to_kv_pool.available_size()
|
self.token_to_kv_pool_allocator.available_size()
|
||||||
< get_required_tokens(len(sorted_indices))
|
< get_required_tokens(len(sorted_indices))
|
||||||
or first_iter
|
or first_iter
|
||||||
):
|
):
|
||||||
if len(sorted_indices) == 1:
|
if len(sorted_indices) == 1:
|
||||||
# Corner case: only one request left
|
# Corner case: only one request left
|
||||||
assert (
|
assert (
|
||||||
self.token_to_kv_pool.available_size() > 0
|
self.token_to_kv_pool_allocator.available_size() > 0
|
||||||
), "No space left for only one request"
|
), "No space left for only one request"
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -983,7 +982,7 @@ class ScheduleBatch:
|
|||||||
token_indices = self.req_to_token_pool.req_to_token[
|
token_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx, : seq_lens_cpu[idx]
|
req.req_pool_idx, : seq_lens_cpu[idx]
|
||||||
]
|
]
|
||||||
self.token_to_kv_pool.free(token_indices)
|
self.token_to_kv_pool_allocator.free(token_indices)
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
del self.tree_cache.entries[req.rid]
|
del self.tree_cache.entries[req.rid]
|
||||||
else:
|
else:
|
||||||
@@ -992,7 +991,7 @@ class ScheduleBatch:
|
|||||||
token_indices = self.req_to_token_pool.req_to_token[
|
token_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
|
||||||
]
|
]
|
||||||
self.token_to_kv_pool.free(token_indices)
|
self.token_to_kv_pool_allocator.free(token_indices)
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
|
|
||||||
# release the last node
|
# release the last node
|
||||||
@@ -1001,10 +1000,13 @@ class ScheduleBatch:
|
|||||||
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
# NOTE(lsyin): we should use the newly evictable memory instantly.
|
||||||
residual_size = (
|
residual_size = (
|
||||||
len(sorted_indices) * global_config.retract_decode_steps
|
len(sorted_indices) * global_config.retract_decode_steps
|
||||||
- self.token_to_kv_pool.available_size()
|
- self.token_to_kv_pool_allocator.available_size()
|
||||||
)
|
)
|
||||||
residual_size = max(0, residual_size)
|
residual_size = max(0, residual_size)
|
||||||
self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
|
self.tree_cache.evict(
|
||||||
|
residual_size, self.token_to_kv_pool_allocator.free
|
||||||
|
)
|
||||||
|
|
||||||
req.reset_for_retract()
|
req.reset_for_retract()
|
||||||
|
|
||||||
self.filter_batch(keep_indices=sorted_indices)
|
self.filter_batch(keep_indices=sorted_indices)
|
||||||
@@ -1183,7 +1185,7 @@ class ScheduleBatch:
|
|||||||
if self.spec_info:
|
if self.spec_info:
|
||||||
self.spec_info.merge_batch(other.spec_info)
|
self.spec_info.merge_batch(other.spec_info)
|
||||||
|
|
||||||
def get_model_worker_batch(self):
|
def get_model_worker_batch(self) -> ModelWorkerBatch:
|
||||||
if self.forward_mode.is_decode_or_idle():
|
if self.forward_mode.is_decode_or_idle():
|
||||||
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
extend_seq_lens = extend_prefix_lens = extend_logprob_start_lens = None
|
||||||
else:
|
else:
|
||||||
@@ -1273,7 +1275,7 @@ class ModelWorkerBatch:
|
|||||||
req_pool_indices: torch.Tensor
|
req_pool_indices: torch.Tensor
|
||||||
# The sequence length
|
# The sequence length
|
||||||
seq_lens: torch.Tensor
|
seq_lens: torch.Tensor
|
||||||
# The indices of output tokens in the token_to_kv_pool
|
# The indices of output tokens in the token_to_kv_pool_allocator
|
||||||
out_cache_loc: torch.Tensor
|
out_cache_loc: torch.Tensor
|
||||||
|
|
||||||
# The sum of all sequence lengths
|
# The sum of all sequence lengths
|
||||||
|
|||||||
@@ -22,9 +22,13 @@ from typing import Dict, List, Optional, Set, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
|
from sglang.srt.managers.schedule_batch import (
|
||||||
|
Req,
|
||||||
|
ScheduleBatch,
|
||||||
|
global_server_args_dict,
|
||||||
|
)
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool
|
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
|
||||||
|
|
||||||
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
|
# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
|
||||||
@@ -75,7 +79,7 @@ class SchedulePolicy:
|
|||||||
|
|
||||||
# It is used to find the matching prefix for in-batch prefix caching.
|
# It is used to find the matching prefix for in-batch prefix caching.
|
||||||
self.waiting_queue_radix_tree = RadixCache(
|
self.waiting_queue_radix_tree = RadixCache(
|
||||||
req_to_token_pool=None, token_to_kv_pool=None, disable=False
|
req_to_token_pool=None, token_to_kv_pool_allocator=None, disable=False
|
||||||
)
|
)
|
||||||
|
|
||||||
def calc_priority(self, waiting_queue: List[Req]) -> bool:
|
def calc_priority(self, waiting_queue: List[Req]) -> bool:
|
||||||
@@ -251,7 +255,7 @@ class PrefillAdder:
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tree_cache: BasePrefixCache,
|
tree_cache: BasePrefixCache,
|
||||||
token_to_kv_pool: BaseTokenToKVPool,
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
running_batch: ScheduleBatch,
|
running_batch: ScheduleBatch,
|
||||||
new_token_ratio: float,
|
new_token_ratio: float,
|
||||||
rem_input_tokens: int,
|
rem_input_tokens: int,
|
||||||
@@ -259,7 +263,7 @@ class PrefillAdder:
|
|||||||
mixed_with_decode_tokens: int = 0,
|
mixed_with_decode_tokens: int = 0,
|
||||||
):
|
):
|
||||||
self.tree_cache = tree_cache
|
self.tree_cache = tree_cache
|
||||||
self.token_to_kv_pool = token_to_kv_pool
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
self.running_batch = running_batch
|
self.running_batch = running_batch
|
||||||
self.new_token_ratio = new_token_ratio
|
self.new_token_ratio = new_token_ratio
|
||||||
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
|
||||||
@@ -291,7 +295,7 @@ class PrefillAdder:
|
|||||||
@property
|
@property
|
||||||
def rem_total_tokens(self):
|
def rem_total_tokens(self):
|
||||||
return (
|
return (
|
||||||
self.token_to_kv_pool.available_size()
|
self.token_to_kv_pool_allocator.available_size()
|
||||||
+ self.tree_cache.evictable_size()
|
+ self.tree_cache.evictable_size()
|
||||||
- self.rem_total_token_offset
|
- self.rem_total_token_offset
|
||||||
)
|
)
|
||||||
@@ -299,7 +303,7 @@ class PrefillAdder:
|
|||||||
@property
|
@property
|
||||||
def cur_rem_tokens(self):
|
def cur_rem_tokens(self):
|
||||||
return (
|
return (
|
||||||
self.token_to_kv_pool.available_size()
|
self.token_to_kv_pool_allocator.available_size()
|
||||||
+ self.tree_cache.evictable_size()
|
+ self.tree_cache.evictable_size()
|
||||||
- self.cur_rem_token_offset
|
- self.cur_rem_token_offset
|
||||||
)
|
)
|
||||||
@@ -332,7 +336,6 @@ class PrefillAdder:
|
|||||||
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
|
||||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
|
|
||||||
self._prefill_one_req(
|
self._prefill_one_req(
|
||||||
0,
|
0,
|
||||||
req.extend_input_len,
|
req.extend_input_len,
|
||||||
@@ -400,8 +403,8 @@ class PrefillAdder:
|
|||||||
tokens_freed += tokens_occupied
|
tokens_freed += tokens_occupied
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.rem_chunk_tokens is None
|
self.rem_chunk_tokens is None # chunked prefill is disabled
|
||||||
or req.extend_input_len <= self.rem_chunk_tokens
|
or req.extend_input_len <= self.rem_chunk_tokens # it is the last chunk
|
||||||
):
|
):
|
||||||
# Non-chunked prefill
|
# Non-chunked prefill
|
||||||
self.can_run_list.append(req)
|
self.can_run_list.append(req)
|
||||||
@@ -411,10 +414,11 @@ class PrefillAdder:
|
|||||||
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
|
min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if self.rem_chunk_tokens == 0:
|
||||||
|
return AddReqResult.OTHER
|
||||||
|
|
||||||
# Chunked prefill
|
# Chunked prefill
|
||||||
trunc_len = self.rem_chunk_tokens
|
trunc_len = self.rem_chunk_tokens
|
||||||
if trunc_len == 0:
|
|
||||||
return AddReqResult.OTHER
|
|
||||||
|
|
||||||
req.extend_input_len = trunc_len
|
req.extend_input_len = trunc_len
|
||||||
req.fill_ids = req.fill_ids[:trunc_len]
|
req.fill_ids = req.fill_ids[:trunc_len]
|
||||||
@@ -457,10 +461,11 @@ class PrefillAdder:
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
if self.rem_chunk_tokens == 0:
|
||||||
|
return AddReqResult.OTHER
|
||||||
|
|
||||||
# Chunked prefill
|
# Chunked prefill
|
||||||
trunc_len = self.rem_chunk_tokens
|
trunc_len = self.rem_chunk_tokens
|
||||||
if trunc_len == 0:
|
|
||||||
return AddReqResult.OTHER
|
|
||||||
|
|
||||||
req.extend_input_len = trunc_len
|
req.extend_input_len = trunc_len
|
||||||
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ class Scheduler:
|
|||||||
self.server_args.speculative_num_draft_tokens
|
self.server_args.speculative_num_draft_tokens
|
||||||
+ (
|
+ (
|
||||||
self.server_args.speculative_eagle_topk
|
self.server_args.speculative_eagle_topk
|
||||||
* self.server_args.speculative_num_steps
|
* self.server_args.speculative_num_draft_tokens
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if not self.spec_algorithm.is_none()
|
if not self.spec_algorithm.is_none()
|
||||||
@@ -309,7 +309,9 @@ class Scheduler:
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Init memory pool and cache
|
# Init memory pool and cache
|
||||||
self.req_to_token_pool, self.token_to_kv_pool = self.tp_worker.get_memory_pool()
|
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
|
||||||
|
self.tp_worker.get_memory_pool()
|
||||||
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
server_args.chunked_prefill_size is not None
|
server_args.chunked_prefill_size is not None
|
||||||
@@ -317,18 +319,18 @@ class Scheduler:
|
|||||||
):
|
):
|
||||||
self.tree_cache = ChunkCache(
|
self.tree_cache = ChunkCache(
|
||||||
req_to_token_pool=self.req_to_token_pool,
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
token_to_kv_pool=self.token_to_kv_pool,
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if self.enable_hierarchical_cache:
|
if self.enable_hierarchical_cache:
|
||||||
self.tree_cache = HiRadixCache(
|
self.tree_cache = HiRadixCache(
|
||||||
req_to_token_pool=self.req_to_token_pool,
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
token_to_kv_pool=self.token_to_kv_pool,
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
self.tree_cache = RadixCache(
|
self.tree_cache = RadixCache(
|
||||||
req_to_token_pool=self.req_to_token_pool,
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
token_to_kv_pool=self.token_to_kv_pool,
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
disable=server_args.disable_radix_cache,
|
disable=server_args.disable_radix_cache,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -458,7 +460,6 @@ class Scheduler:
|
|||||||
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
|
(ResumeMemoryOccupationReqInput, self.resume_memory_occupation),
|
||||||
(ProfileReq, self.profile),
|
(ProfileReq, self.profile),
|
||||||
(GetInternalStateReq, self.get_internal_state),
|
(GetInternalStateReq, self.get_internal_state),
|
||||||
(SetInternalStateReq, self.set_internal_state),
|
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -809,7 +810,8 @@ class Scheduler:
|
|||||||
running_bs: int,
|
running_bs: int,
|
||||||
):
|
):
|
||||||
num_used = self.max_total_num_tokens - (
|
num_used = self.max_total_num_tokens - (
|
||||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
self.token_to_kv_pool_allocator.available_size()
|
||||||
|
+ self.tree_cache.evictable_size()
|
||||||
)
|
)
|
||||||
self._largest_prefill_len = max(
|
self._largest_prefill_len = max(
|
||||||
self._largest_prefill_len, adder.log_input_tokens
|
self._largest_prefill_len, adder.log_input_tokens
|
||||||
@@ -844,7 +846,8 @@ class Scheduler:
|
|||||||
self.num_generated_tokens = 0
|
self.num_generated_tokens = 0
|
||||||
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
|
num_running_reqs = len(self.running_batch.reqs) if self.running_batch else 0
|
||||||
num_used = self.max_total_num_tokens - (
|
num_used = self.max_total_num_tokens - (
|
||||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
self.token_to_kv_pool_allocator.available_size()
|
||||||
|
+ self.tree_cache.evictable_size()
|
||||||
)
|
)
|
||||||
|
|
||||||
if RECORD_STEP_TIME:
|
if RECORD_STEP_TIME:
|
||||||
@@ -894,7 +897,8 @@ class Scheduler:
|
|||||||
|
|
||||||
def check_memory(self):
|
def check_memory(self):
|
||||||
available_size = (
|
available_size = (
|
||||||
self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
|
self.token_to_kv_pool_allocator.available_size()
|
||||||
|
+ self.tree_cache.evictable_size()
|
||||||
)
|
)
|
||||||
protected_size = self.tree_cache.protected_size()
|
protected_size = self.tree_cache.protected_size()
|
||||||
memory_leak = available_size != (
|
memory_leak = available_size != (
|
||||||
@@ -999,7 +1003,7 @@ class Scheduler:
|
|||||||
# Prefill policy
|
# Prefill policy
|
||||||
adder = PrefillAdder(
|
adder = PrefillAdder(
|
||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
self.token_to_kv_pool,
|
self.token_to_kv_pool_allocator,
|
||||||
self.running_batch,
|
self.running_batch,
|
||||||
self.new_token_ratio,
|
self.new_token_ratio,
|
||||||
self.max_prefill_tokens,
|
self.max_prefill_tokens,
|
||||||
@@ -1099,7 +1103,7 @@ class Scheduler:
|
|||||||
new_batch = ScheduleBatch.init_new(
|
new_batch = ScheduleBatch.init_new(
|
||||||
can_run_list,
|
can_run_list,
|
||||||
self.req_to_token_pool,
|
self.req_to_token_pool,
|
||||||
self.token_to_kv_pool,
|
self.token_to_kv_pool_allocator,
|
||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
self.model_config,
|
self.model_config,
|
||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
@@ -1143,8 +1147,6 @@ class Scheduler:
|
|||||||
|
|
||||||
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
|
retracted_reqs, new_token_ratio = batch.retract_decode(self.server_args)
|
||||||
self.new_token_ratio = new_token_ratio
|
self.new_token_ratio = new_token_ratio
|
||||||
if self.draft_worker:
|
|
||||||
self.draft_worker.finish_request(retracted_reqs)
|
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"Decode out of memory happened. "
|
"Decode out of memory happened. "
|
||||||
@@ -1184,11 +1186,12 @@ class Scheduler:
|
|||||||
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
logits_output, next_token_ids = self.tp_worker.forward_batch_generation(
|
||||||
model_worker_batch
|
model_worker_batch
|
||||||
)
|
)
|
||||||
|
bid = model_worker_batch.bid
|
||||||
else:
|
else:
|
||||||
(
|
(
|
||||||
logits_output,
|
logits_output,
|
||||||
next_token_ids,
|
next_token_ids,
|
||||||
model_worker_batch,
|
bid,
|
||||||
num_accepted_tokens,
|
num_accepted_tokens,
|
||||||
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
) = self.draft_worker.forward_batch_speculative_generation(batch)
|
||||||
self.spec_num_total_accepted_tokens += (
|
self.spec_num_total_accepted_tokens += (
|
||||||
@@ -1214,7 +1217,7 @@ class Scheduler:
|
|||||||
next_token_ids=next_token_ids,
|
next_token_ids=next_token_ids,
|
||||||
extend_input_len_per_req=extend_input_len_per_req,
|
extend_input_len_per_req=extend_input_len_per_req,
|
||||||
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
extend_logprob_start_len_per_req=extend_logprob_start_len_per_req,
|
||||||
bid=model_worker_batch.bid,
|
bid=bid,
|
||||||
)
|
)
|
||||||
else: # embedding or reward model
|
else: # embedding or reward model
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
@@ -1230,6 +1233,7 @@ class Scheduler:
|
|||||||
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
result: Union[GenerationBatchResult, EmbeddingBatchResult],
|
||||||
):
|
):
|
||||||
if batch.forward_mode.is_decode():
|
if batch.forward_mode.is_decode():
|
||||||
|
assert isinstance(result, GenerationBatchResult)
|
||||||
self.process_batch_result_decode(batch, result)
|
self.process_batch_result_decode(batch, result)
|
||||||
if batch.is_empty():
|
if batch.is_empty():
|
||||||
self.running_batch = None
|
self.running_batch = None
|
||||||
@@ -1302,7 +1306,7 @@ class Scheduler:
|
|||||||
if self.is_mixed_chunk and self.enable_overlap and req.finished():
|
if self.is_mixed_chunk and self.enable_overlap and req.finished():
|
||||||
# Free the one delayed token for the mixed decode batch
|
# Free the one delayed token for the mixed decode batch
|
||||||
j = len(batch.out_cache_loc) - len(batch.reqs) + i
|
j = len(batch.out_cache_loc) - len(batch.reqs) + i
|
||||||
self.token_to_kv_pool.free(batch.out_cache_loc[j : j + 1])
|
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[j : j + 1])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if req.is_chunked <= 0:
|
if req.is_chunked <= 0:
|
||||||
@@ -1420,23 +1424,27 @@ class Scheduler:
|
|||||||
self.num_generated_tokens += len(batch.reqs)
|
self.num_generated_tokens += len(batch.reqs)
|
||||||
|
|
||||||
if self.enable_overlap:
|
if self.enable_overlap:
|
||||||
|
assert batch.spec_algorithm.is_none()
|
||||||
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
logits_output, next_token_ids = self.tp_worker.resolve_batch_result(bid)
|
||||||
next_token_logprobs = logits_output.next_token_logprobs
|
next_token_logprobs = logits_output.next_token_logprobs
|
||||||
else:
|
elif batch.spec_algorithm.is_none():
|
||||||
|
# spec decoding handles output logprobs inside verify process.
|
||||||
next_token_ids = next_token_ids.tolist()
|
next_token_ids = next_token_ids.tolist()
|
||||||
if batch.return_logprob:
|
if batch.return_logprob:
|
||||||
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
next_token_logprobs = logits_output.next_token_logprobs.tolist()
|
||||||
|
|
||||||
self.token_to_kv_pool.free_group_begin()
|
self.token_to_kv_pool_allocator.free_group_begin()
|
||||||
|
|
||||||
# Check finish condition
|
# Check finish condition
|
||||||
|
# NOTE: the length of reqs and next_token_ids don't match if it is spec decoding.
|
||||||
|
# We should ignore using next_token_ids for spec decoding cases.
|
||||||
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
|
||||||
if req.is_retracted:
|
if req.is_retracted:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if self.enable_overlap and req.finished():
|
if self.enable_overlap and req.finished():
|
||||||
# Free the one delayed token
|
# Free the one delayed token
|
||||||
self.token_to_kv_pool.free(batch.out_cache_loc[i : i + 1])
|
self.token_to_kv_pool_allocator.free(batch.out_cache_loc[i : i + 1])
|
||||||
continue
|
continue
|
||||||
|
|
||||||
if batch.spec_algorithm.is_none():
|
if batch.spec_algorithm.is_none():
|
||||||
@@ -1479,7 +1487,7 @@ class Scheduler:
|
|||||||
batch.next_batch_sampling_info.sampling_info_done.set()
|
batch.next_batch_sampling_info.sampling_info_done.set()
|
||||||
self.stream_output(batch.reqs, batch.return_logprob)
|
self.stream_output(batch.reqs, batch.return_logprob)
|
||||||
|
|
||||||
self.token_to_kv_pool.free_group_end()
|
self.token_to_kv_pool_allocator.free_group_end()
|
||||||
|
|
||||||
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
|
self.forward_ct_decode = (self.forward_ct_decode + 1) % (1 << 30)
|
||||||
if (
|
if (
|
||||||
@@ -1718,9 +1726,6 @@ class Scheduler:
|
|||||||
and not self.model_config.is_multimodal_gen
|
and not self.model_config.is_multimodal_gen
|
||||||
)
|
)
|
||||||
):
|
):
|
||||||
if self.draft_worker and req.finished():
|
|
||||||
self.draft_worker.finish_request(req)
|
|
||||||
|
|
||||||
rids.append(req.rid)
|
rids.append(req.rid)
|
||||||
finished_reasons.append(
|
finished_reasons.append(
|
||||||
req.finished_reason.to_json() if req.finished_reason else None
|
req.finished_reason.to_json() if req.finished_reason else None
|
||||||
@@ -1860,7 +1865,7 @@ class Scheduler:
|
|||||||
idle_batch = ScheduleBatch.init_new(
|
idle_batch = ScheduleBatch.init_new(
|
||||||
[],
|
[],
|
||||||
self.req_to_token_pool,
|
self.req_to_token_pool,
|
||||||
self.token_to_kv_pool,
|
self.token_to_kv_pool_allocator,
|
||||||
self.tree_cache,
|
self.tree_cache,
|
||||||
self.model_config,
|
self.model_config,
|
||||||
self.enable_overlap,
|
self.enable_overlap,
|
||||||
@@ -1916,11 +1921,11 @@ class Scheduler:
|
|||||||
if self.grammar_backend:
|
if self.grammar_backend:
|
||||||
self.grammar_backend.reset()
|
self.grammar_backend.reset()
|
||||||
self.req_to_token_pool.clear()
|
self.req_to_token_pool.clear()
|
||||||
self.token_to_kv_pool.clear()
|
self.token_to_kv_pool_allocator.clear()
|
||||||
|
|
||||||
if not self.spec_algorithm.is_none():
|
if not self.spec_algorithm.is_none():
|
||||||
self.draft_worker.model_runner.req_to_token_pool.clear()
|
self.draft_worker.model_runner.req_to_token_pool.clear()
|
||||||
self.draft_worker.model_runner.token_to_kv_pool.clear()
|
self.draft_worker.model_runner.token_to_kv_pool_allocator.clear()
|
||||||
|
|
||||||
self.num_generated_tokens = 0
|
self.num_generated_tokens = 0
|
||||||
self.forward_ct_decode = 0
|
self.forward_ct_decode = 0
|
||||||
|
|||||||
@@ -82,8 +82,6 @@ from sglang.srt.managers.io_struct import (
|
|||||||
ResumeMemoryOccupationReqInput,
|
ResumeMemoryOccupationReqInput,
|
||||||
ResumeMemoryOccupationReqOutput,
|
ResumeMemoryOccupationReqOutput,
|
||||||
SessionParams,
|
SessionParams,
|
||||||
SetInternalStateReq,
|
|
||||||
SetInternalStateReqOutput,
|
|
||||||
TokenizedEmbeddingReqInput,
|
TokenizedEmbeddingReqInput,
|
||||||
TokenizedGenerateReqInput,
|
TokenizedGenerateReqInput,
|
||||||
UpdateWeightFromDiskReqInput,
|
UpdateWeightFromDiskReqInput,
|
||||||
@@ -257,9 +255,6 @@ class TokenizerManager:
|
|||||||
self.get_internal_state_communicator = _Communicator(
|
self.get_internal_state_communicator = _Communicator(
|
||||||
self.send_to_scheduler, server_args.dp_size
|
self.send_to_scheduler, server_args.dp_size
|
||||||
)
|
)
|
||||||
self.set_internal_state_communicator = _Communicator(
|
|
||||||
self.send_to_scheduler, server_args.dp_size
|
|
||||||
)
|
|
||||||
|
|
||||||
self._result_dispatcher = TypeBasedDispatcher(
|
self._result_dispatcher = TypeBasedDispatcher(
|
||||||
[
|
[
|
||||||
@@ -309,10 +304,6 @@ class TokenizerManager:
|
|||||||
GetInternalStateReqOutput,
|
GetInternalStateReqOutput,
|
||||||
self.get_internal_state_communicator.handle_recv,
|
self.get_internal_state_communicator.handle_recv,
|
||||||
),
|
),
|
||||||
(
|
|
||||||
SetInternalStateReqOutput,
|
|
||||||
self.set_internal_state_communicator.handle_recv,
|
|
||||||
),
|
|
||||||
(HealthCheckOutput, lambda x: None),
|
(HealthCheckOutput, lambda x: None),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
@@ -774,14 +765,6 @@ class TokenizerManager:
|
|||||||
)
|
)
|
||||||
return res[0].internal_state
|
return res[0].internal_state
|
||||||
|
|
||||||
async def set_internal_state(
|
|
||||||
self, obj: SetInternalStateReq
|
|
||||||
) -> SetInternalStateReqOutput:
|
|
||||||
res: List[SetInternalStateReqOutput] = (
|
|
||||||
await self.set_internal_state_communicator(obj)
|
|
||||||
)
|
|
||||||
return res[0]
|
|
||||||
|
|
||||||
def get_log_request_metadata(self):
|
def get_log_request_metadata(self):
|
||||||
max_length = None
|
max_length = None
|
||||||
skip_names = None
|
skip_names = None
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
|
|||||||
UpdateWeightsFromTensorReqInput,
|
UpdateWeightsFromTensorReqInput,
|
||||||
)
|
)
|
||||||
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
|
||||||
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||||
from sglang.srt.server_args import ServerArgs
|
from sglang.srt.server_args import ServerArgs
|
||||||
@@ -49,6 +50,8 @@ class TpModelWorker:
|
|||||||
dp_rank: Optional[int],
|
dp_rank: Optional[int],
|
||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
is_draft_worker: bool = False,
|
is_draft_worker: bool = False,
|
||||||
|
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
||||||
|
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
|
||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.tp_rank = tp_rank
|
self.tp_rank = tp_rank
|
||||||
@@ -77,6 +80,8 @@ class TpModelWorker:
|
|||||||
nccl_port=nccl_port,
|
nccl_port=nccl_port,
|
||||||
server_args=server_args,
|
server_args=server_args,
|
||||||
is_draft_worker=is_draft_worker,
|
is_draft_worker=is_draft_worker,
|
||||||
|
req_to_token_pool=req_to_token_pool,
|
||||||
|
token_to_kv_pool_allocator=token_to_kv_pool_allocator,
|
||||||
)
|
)
|
||||||
if server_args.skip_tokenizer_init:
|
if server_args.skip_tokenizer_init:
|
||||||
self.tokenizer = self.processor = None
|
self.tokenizer = self.processor = None
|
||||||
@@ -154,7 +159,7 @@ class TpModelWorker:
|
|||||||
def get_memory_pool(self):
|
def get_memory_pool(self):
|
||||||
return (
|
return (
|
||||||
self.model_runner.req_to_token_pool,
|
self.model_runner.req_to_token_pool,
|
||||||
self.model_runner.token_to_kv_pool,
|
self.model_runner.token_to_kv_pool_allocator,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_batch_generation(
|
def forward_batch_generation(
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ class TpModelWorkerClient:
|
|||||||
def get_memory_pool(self):
|
def get_memory_pool(self):
|
||||||
return (
|
return (
|
||||||
self.worker.model_runner.req_to_token_pool,
|
self.worker.model_runner.req_to_token_pool,
|
||||||
self.worker.model_runner.token_to_kv_pool,
|
self.worker.model_runner.token_to_kv_pool_allocator,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_thread_func(self):
|
def forward_thread_func(self):
|
||||||
|
|||||||
@@ -1,13 +1,12 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
"""Cache for chunked prefill, used when RadixCache is disabled."""
|
||||||
|
|
||||||
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
@@ -21,11 +20,13 @@ class ChunkCacheEntry:
|
|||||||
|
|
||||||
class ChunkCache(BasePrefixCache):
|
class ChunkCache(BasePrefixCache):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
|
self,
|
||||||
|
req_to_token_pool: ReqToTokenPool,
|
||||||
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
):
|
):
|
||||||
self.disable = True
|
self.disable = True
|
||||||
self.req_to_token_pool = req_to_token_pool
|
self.req_to_token_pool = req_to_token_pool
|
||||||
self.token_to_kv_pool = token_to_kv_pool
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
self.entries: Dict[str, ChunkCacheEntry] = {}
|
self.entries: Dict[str, ChunkCacheEntry] = {}
|
||||||
|
|
||||||
self.reset()
|
self.reset()
|
||||||
@@ -51,7 +52,7 @@ class ChunkCache(BasePrefixCache):
|
|||||||
req.req_pool_idx, :token_id_len
|
req.req_pool_idx, :token_id_len
|
||||||
]
|
]
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
self.token_to_kv_pool.free(kv_indices)
|
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||||
|
|
||||||
if req.rid in self.entries:
|
if req.rid in self.entries:
|
||||||
del self.entries[req.rid]
|
del self.entries[req.rid]
|
||||||
@@ -91,3 +92,6 @@ class ChunkCache(BasePrefixCache):
|
|||||||
|
|
||||||
def protected_size(self):
|
def protected_size(self):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
def pretty_print(self):
|
||||||
|
return ""
|
||||||
|
|||||||
@@ -7,8 +7,8 @@ import torch
|
|||||||
|
|
||||||
from sglang.srt.managers.cache_controller import HiCacheController
|
from sglang.srt.managers.cache_controller import HiCacheController
|
||||||
from sglang.srt.mem_cache.memory_pool import (
|
from sglang.srt.mem_cache.memory_pool import (
|
||||||
BaseTokenToKVPool,
|
MHATokenToKVPool,
|
||||||
MLATokenToKVPoolHost,
|
MHATokenToKVPoolHost,
|
||||||
ReqToTokenPool,
|
ReqToTokenPool,
|
||||||
)
|
)
|
||||||
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
|
from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode, _key_match
|
||||||
@@ -21,9 +21,9 @@ class HiRadixCache(RadixCache):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
req_to_token_pool: ReqToTokenPool,
|
req_to_token_pool: ReqToTokenPool,
|
||||||
token_to_kv_pool: BaseTokenToKVPool,
|
token_to_kv_pool: MHATokenToKVPool,
|
||||||
):
|
):
|
||||||
self.token_to_kv_pool_host = MLATokenToKVPoolHost(token_to_kv_pool)
|
self.token_to_kv_pool_host = MHATokenToKVPoolHost(token_to_kv_pool)
|
||||||
self.cache_controller = HiCacheController(
|
self.cache_controller = HiCacheController(
|
||||||
token_to_kv_pool, self.token_to_kv_pool_host
|
token_to_kv_pool, self.token_to_kv_pool_host
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -20,9 +20,12 @@ Memory pool.
|
|||||||
|
|
||||||
SGLang has two levels of memory pool.
|
SGLang has two levels of memory pool.
|
||||||
ReqToTokenPool maps a a request to its token locations.
|
ReqToTokenPool maps a a request to its token locations.
|
||||||
BaseTokenToKVPool maps a token location to its KV cache data.
|
TokenToKVPoolAllocator maps a token location to its KV cache data.
|
||||||
|
KVCache actually holds the physical kv cache. Allocation indices are allocated
|
||||||
|
by TokenToKVPoolAllocator
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import abc
|
||||||
import logging
|
import logging
|
||||||
import threading
|
import threading
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
@@ -89,7 +92,7 @@ class ReqToTokenPool:
|
|||||||
self.free_slots = list(range(self.size))
|
self.free_slots = list(range(self.size))
|
||||||
|
|
||||||
|
|
||||||
class BaseTokenToKVPool:
|
class TokenToKVPoolAllocator:
|
||||||
"""A memory pool that maps a token location to its kv cache data."""
|
"""A memory pool that maps a token location to its kv cache data."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -100,11 +103,6 @@ class BaseTokenToKVPool:
|
|||||||
):
|
):
|
||||||
self.size = size
|
self.size = size
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
|
||||||
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
|
||||||
self.store_dtype = torch.uint8
|
|
||||||
else:
|
|
||||||
self.store_dtype = dtype
|
|
||||||
self.device = device
|
self.device = device
|
||||||
|
|
||||||
self.free_slots = None
|
self.free_slots = None
|
||||||
@@ -148,15 +146,22 @@ class BaseTokenToKVPool:
|
|||||||
self.is_in_free_group = False
|
self.is_in_free_group = False
|
||||||
self.free_group = []
|
self.free_group = []
|
||||||
|
|
||||||
|
|
||||||
|
class KVCache(abc.ABC):
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
def get_key_buffer(self, layer_id: int) -> torch.Tensor:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
|
def get_value_buffer(self, layer_id: int) -> torch.Tensor:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@abc.abstractmethod
|
||||||
def set_kv_buffer(
|
def set_kv_buffer(
|
||||||
self,
|
self,
|
||||||
layer: RadixAttention,
|
layer: RadixAttention,
|
||||||
@@ -167,7 +172,7 @@ class BaseTokenToKVPool:
|
|||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
class MHATokenToKVPool(BaseTokenToKVPool):
|
class MHATokenToKVPool(KVCache):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
@@ -179,8 +184,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
|
|||||||
device: str,
|
device: str,
|
||||||
enable_memory_saver: bool,
|
enable_memory_saver: bool,
|
||||||
):
|
):
|
||||||
super().__init__(size, dtype, device)
|
self.size = size
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||||
|
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
||||||
|
self.store_dtype = torch.uint8
|
||||||
|
else:
|
||||||
|
self.store_dtype = dtype
|
||||||
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
self.memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||||
enable=enable_memory_saver
|
enable=enable_memory_saver
|
||||||
)
|
)
|
||||||
@@ -297,7 +308,7 @@ def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
|
|||||||
dst_2[loc] = src_2.to(dtype).view(store_dtype)
|
dst_2[loc] = src_2.to(dtype).view(store_dtype)
|
||||||
|
|
||||||
|
|
||||||
class MLATokenToKVPool(BaseTokenToKVPool):
|
class MLATokenToKVPool(KVCache):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
size: int,
|
size: int,
|
||||||
@@ -308,8 +319,14 @@ class MLATokenToKVPool(BaseTokenToKVPool):
|
|||||||
device: str,
|
device: str,
|
||||||
enable_memory_saver: bool,
|
enable_memory_saver: bool,
|
||||||
):
|
):
|
||||||
super().__init__(size, dtype, device)
|
self.size = size
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||||
|
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
||||||
|
self.store_dtype = torch.uint8
|
||||||
|
else:
|
||||||
|
self.store_dtype = dtype
|
||||||
self.kv_lora_rank = kv_lora_rank
|
self.kv_lora_rank = kv_lora_rank
|
||||||
|
|
||||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||||
@@ -356,7 +373,7 @@ class MLATokenToKVPool(BaseTokenToKVPool):
|
|||||||
self.kv_buffer[layer_id][loc] = cache_k
|
self.kv_buffer[layer_id][loc] = cache_k
|
||||||
|
|
||||||
|
|
||||||
class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
|
class DoubleSparseTokenToKVPool(KVCache):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
size: int,
|
size: int,
|
||||||
@@ -368,8 +385,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
|
|||||||
heavy_channel_num: int,
|
heavy_channel_num: int,
|
||||||
enable_memory_saver: bool,
|
enable_memory_saver: bool,
|
||||||
):
|
):
|
||||||
super().__init__(size, dtype, device)
|
self.size = size
|
||||||
|
self.dtype = dtype
|
||||||
|
self.device = device
|
||||||
|
if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
|
||||||
|
# NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
|
||||||
|
self.store_dtype = torch.uint8
|
||||||
|
else:
|
||||||
|
self.store_dtype = dtype
|
||||||
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
memory_saver_adapter = TorchMemorySaverAdapter.create(
|
||||||
enable=enable_memory_saver
|
enable=enable_memory_saver
|
||||||
)
|
)
|
||||||
@@ -437,12 +460,12 @@ def synchronized(func):
|
|||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
class MLATokenToKVPoolHost:
|
class MHATokenToKVPoolHost:
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device_pool: MHATokenToKVPool,
|
device_pool: MHATokenToKVPool,
|
||||||
host_to_device_ratio: float = 4.0,
|
host_to_device_ratio: float = 2.0,
|
||||||
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
pin_memory: bool = False, # no need to use pin memory with the double buffering
|
||||||
device: str = "cpu",
|
device: str = "cpu",
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -26,8 +26,9 @@ from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||||
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
|
from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.managers.schedule_batch import Req
|
from sglang.srt.managers.schedule_batch import Req
|
||||||
@@ -79,11 +80,11 @@ class RadixCache(BasePrefixCache):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
req_to_token_pool: ReqToTokenPool,
|
req_to_token_pool: ReqToTokenPool,
|
||||||
token_to_kv_pool: BaseTokenToKVPool,
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
disable: bool = False,
|
disable: bool = False,
|
||||||
):
|
):
|
||||||
self.req_to_token_pool = req_to_token_pool
|
self.req_to_token_pool = req_to_token_pool
|
||||||
self.token_to_kv_pool = token_to_kv_pool
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
self.disable = disable
|
self.disable = disable
|
||||||
self.reset()
|
self.reset()
|
||||||
|
|
||||||
@@ -139,7 +140,7 @@ class RadixCache(BasePrefixCache):
|
|||||||
kv_indices = self.req_to_token_pool.req_to_token[
|
kv_indices = self.req_to_token_pool.req_to_token[
|
||||||
req.req_pool_idx, :token_ids_len
|
req.req_pool_idx, :token_ids_len
|
||||||
]
|
]
|
||||||
self.token_to_kv_pool.free(kv_indices)
|
self.token_to_kv_pool_allocator.free(kv_indices)
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -151,7 +152,9 @@ class RadixCache(BasePrefixCache):
|
|||||||
|
|
||||||
# Radix Cache takes one ref in memory pool
|
# Radix Cache takes one ref in memory pool
|
||||||
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
||||||
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
|
self.token_to_kv_pool_allocator.free(
|
||||||
|
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||||
|
)
|
||||||
|
|
||||||
# Remove the req slot and release the cache lock
|
# Remove the req slot and release the cache lock
|
||||||
self.req_to_token_pool.free(req.req_pool_idx)
|
self.req_to_token_pool.free(req.req_pool_idx)
|
||||||
@@ -171,7 +174,9 @@ class RadixCache(BasePrefixCache):
|
|||||||
|
|
||||||
# Radix Cache takes one ref in memory pool
|
# Radix Cache takes one ref in memory pool
|
||||||
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
new_prefix_len = self.insert(token_ids, kv_indices.clone())
|
||||||
self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
|
self.token_to_kv_pool_allocator.free(
|
||||||
|
kv_indices[len(req.prefix_indices) : new_prefix_len]
|
||||||
|
)
|
||||||
|
|
||||||
# The prefix indices could be updated, reuse it
|
# The prefix indices could be updated, reuse it
|
||||||
new_indices, new_last_node = self.match_prefix(token_ids)
|
new_indices, new_last_node = self.match_prefix(token_ids)
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ from sglang.srt.mem_cache.memory_pool import (
|
|||||||
MHATokenToKVPool,
|
MHATokenToKVPool,
|
||||||
MLATokenToKVPool,
|
MLATokenToKVPool,
|
||||||
ReqToTokenPool,
|
ReqToTokenPool,
|
||||||
|
TokenToKVPoolAllocator,
|
||||||
)
|
)
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
@@ -98,6 +99,8 @@ class ModelRunner:
|
|||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
server_args: ServerArgs,
|
server_args: ServerArgs,
|
||||||
is_draft_worker: bool = False,
|
is_draft_worker: bool = False,
|
||||||
|
req_to_token_pool: Optional[ReqToTokenPool] = None,
|
||||||
|
token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
|
||||||
):
|
):
|
||||||
# Parse args
|
# Parse args
|
||||||
self.model_config = model_config
|
self.model_config = model_config
|
||||||
@@ -115,6 +118,8 @@ class ModelRunner:
|
|||||||
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
self.spec_algorithm = SpeculativeAlgorithm.from_string(
|
||||||
server_args.speculative_algorithm
|
server_args.speculative_algorithm
|
||||||
)
|
)
|
||||||
|
self.req_to_token_pool = req_to_token_pool
|
||||||
|
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
|
||||||
|
|
||||||
# Model-specific adjustment
|
# Model-specific adjustment
|
||||||
if (
|
if (
|
||||||
@@ -257,8 +262,8 @@ class ModelRunner:
|
|||||||
|
|
||||||
def init_torch_distributed(self):
|
def init_torch_distributed(self):
|
||||||
logger.info("Init torch distributed begin.")
|
logger.info("Init torch distributed begin.")
|
||||||
|
|
||||||
torch.get_device_module(self.device).set_device(self.gpu_id)
|
torch.get_device_module(self.device).set_device(self.gpu_id)
|
||||||
|
|
||||||
if self.device == "cuda":
|
if self.device == "cuda":
|
||||||
backend = "nccl"
|
backend = "nccl"
|
||||||
elif self.device == "xpu":
|
elif self.device == "xpu":
|
||||||
@@ -660,12 +665,25 @@ class ModelRunner:
|
|||||||
if not self.spec_algorithm.is_none():
|
if not self.spec_algorithm.is_none():
|
||||||
if self.is_draft_worker:
|
if self.is_draft_worker:
|
||||||
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
|
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
|
||||||
|
max_num_reqs = self.server_args.max_num_reqs
|
||||||
else:
|
else:
|
||||||
|
# We are sharing the `token_to_kv_pool`, and both verify and draft tokens
|
||||||
|
# can be concurrently allocated, so we should give a headroom for it.
|
||||||
self.server_args.draft_runner_cache_size = (
|
self.server_args.draft_runner_cache_size = (
|
||||||
self.max_total_num_tokens
|
self.max_total_num_tokens
|
||||||
+ max_num_reqs * self.server_args.speculative_num_steps
|
# draft
|
||||||
|
+ max_num_reqs
|
||||||
|
* self.server_args.speculative_num_steps
|
||||||
|
* self.server_args.speculative_eagle_topk
|
||||||
|
# verify
|
||||||
|
+ max_num_reqs * self.server_args.speculative_num_draft_tokens
|
||||||
|
# buffer
|
||||||
+ 100
|
+ 100
|
||||||
)
|
)
|
||||||
|
# Target worker and draft worker share the same indices for the
|
||||||
|
# token_to_kv_pool, so we should make sure to match max_total_num_tokens.
|
||||||
|
self.max_total_num_tokens = self.server_args.draft_runner_cache_size
|
||||||
|
self.server_args.max_num_reqs = max_num_reqs
|
||||||
|
|
||||||
if max_total_tokens is not None:
|
if max_total_tokens is not None:
|
||||||
if max_total_tokens > self.max_total_num_tokens:
|
if max_total_tokens > self.max_total_num_tokens:
|
||||||
@@ -681,12 +699,25 @@ class ModelRunner:
|
|||||||
"Not enough memory. Please try to increase --mem-fraction-static."
|
"Not enough memory. Please try to increase --mem-fraction-static."
|
||||||
)
|
)
|
||||||
|
|
||||||
self.req_to_token_pool = ReqToTokenPool(
|
if self.req_to_token_pool is None:
|
||||||
size=max_num_reqs + 1,
|
self.req_to_token_pool = ReqToTokenPool(
|
||||||
max_context_len=self.model_config.context_len + 4,
|
size=max_num_reqs + 1,
|
||||||
device=self.device,
|
max_context_len=self.model_config.context_len + 4,
|
||||||
enable_memory_saver=self.server_args.enable_memory_saver,
|
device=self.device,
|
||||||
)
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# Draft worker shares req_to_token_pool with the target worker.
|
||||||
|
assert self.is_draft_worker
|
||||||
|
|
||||||
|
if self.token_to_kv_pool_allocator is None:
|
||||||
|
self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
|
||||||
|
self.max_total_num_tokens,
|
||||||
|
dtype=self.kv_cache_dtype,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
assert self.is_draft_worker
|
||||||
|
|
||||||
if (
|
if (
|
||||||
self.model_config.attention_arch == AttentionArch.MLA
|
self.model_config.attention_arch == AttentionArch.MLA
|
||||||
|
|||||||
@@ -280,11 +280,16 @@ class ServerArgs:
|
|||||||
self.disable_overlap_schedule = True
|
self.disable_overlap_schedule = True
|
||||||
self.prefill_only_one_req = True
|
self.prefill_only_one_req = True
|
||||||
self.disable_cuda_graph_padding = True
|
self.disable_cuda_graph_padding = True
|
||||||
self.disable_radix_cache = True
|
if self.max_running_requests is None:
|
||||||
self.chunked_prefill_size = -1
|
self.max_running_requests = 32
|
||||||
logger.info(
|
logger.info(
|
||||||
f"The radix cache, chunked prefill, and overlap scheduler are disabled because of using {self.speculative_algorithm} speculative decoding."
|
"Overlap scheduler are disabled because of using "
|
||||||
|
"eagle speculative decoding."
|
||||||
|
"Max running request set to 32 because of using eagle speculative decoding."
|
||||||
)
|
)
|
||||||
|
# The token generated from the verify step is counted.
|
||||||
|
# If speculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
|
||||||
|
assert self.speculative_num_steps < self.speculative_num_draft_tokens
|
||||||
|
|
||||||
# GGUF
|
# GGUF
|
||||||
if (
|
if (
|
||||||
|
|||||||
@@ -3,14 +3,8 @@
|
|||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
|
||||||
from sglang.srt.utils import is_cuda_available
|
from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient
|
||||||
|
|
||||||
if is_cuda_available():
|
|
||||||
from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
|
|
||||||
from sgl_kernel import (
|
|
||||||
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def build_tree_kernel_efficient_preprocess(
|
def build_tree_kernel_efficient_preprocess(
|
||||||
|
|||||||
@@ -21,7 +21,6 @@ from sglang.srt.model_executor.forward_batch_info import (
|
|||||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
|
||||||
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
from sglang.srt.speculative.eagle_worker import EAGLEWorker
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,16 +1,17 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
from dataclasses import dataclass
|
||||||
from typing import TYPE_CHECKING, List
|
from typing import TYPE_CHECKING, Dict, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import triton
|
import triton
|
||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
|
|
||||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
|
||||||
create_flashinfer_kv_indices_triton,
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||||
)
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||||
|
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||||
from sglang.srt.speculative.build_eagle_tree import (
|
from sglang.srt.speculative.build_eagle_tree import (
|
||||||
build_tree_kernel,
|
build_tree_kernel,
|
||||||
@@ -25,7 +26,7 @@ if TYPE_CHECKING:
|
|||||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclass
|
||||||
class EagleDraftInput:
|
class EagleDraftInput:
|
||||||
# The inputs for decode
|
# The inputs for decode
|
||||||
# shape: (b, topk)
|
# shape: (b, topk)
|
||||||
@@ -46,57 +47,46 @@ class EagleDraftInput:
|
|||||||
kv_indptr: torch.Tensor = None
|
kv_indptr: torch.Tensor = None
|
||||||
kv_indices: torch.Tensor = None
|
kv_indices: torch.Tensor = None
|
||||||
|
|
||||||
|
# indices of unfinished requests during extend-after-decode
|
||||||
|
# e.g. [0, 2, 3, 4] if only the request at index 1 is finished
|
||||||
|
keep_indices: List[int] = None
|
||||||
|
|
||||||
def prepare_for_extend(self, batch: ScheduleBatch):
|
def prepare_for_extend(self, batch: ScheduleBatch):
|
||||||
req_pool_indices = batch.alloc_req_slots(len(batch.reqs))
|
assert batch.input_ids.numel() == batch.out_cache_loc.shape[0]
|
||||||
out_cache_loc = batch.alloc_token_slots(batch.input_ids.numel())
|
# Prefill only generates 1 token.
|
||||||
batch.out_cache_loc = out_cache_loc
|
assert len(self.verified_id) == len(batch.seq_lens)
|
||||||
|
|
||||||
pt = 0
|
pt = 0
|
||||||
for i, req in enumerate(batch.reqs):
|
for i, extend_len in enumerate(batch.extend_lens):
|
||||||
req.req_pool_idx = req_pool_indices[i]
|
input_ids = batch.input_ids[pt : pt + extend_len]
|
||||||
pre_len, seq_len = len(req.prefix_indices), len(req.fill_ids)
|
batch.input_ids[pt : pt + extend_len] = torch.concat(
|
||||||
assert seq_len - pre_len == req.extend_input_len
|
(input_ids[1:], self.verified_id[i].reshape(1))
|
||||||
|
|
||||||
if pre_len > 0:
|
|
||||||
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
|
|
||||||
:pre_len
|
|
||||||
] = req.prefix_indices
|
|
||||||
|
|
||||||
batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
|
|
||||||
out_cache_loc[pt : pt + req.extend_input_len]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
pt += req.extend_input_len
|
|
||||||
|
|
||||||
# TODO: support batching inputs
|
|
||||||
assert len(batch.extend_lens) == 1
|
|
||||||
batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
|
|
||||||
|
|
||||||
def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
|
def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
|
||||||
batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
|
assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
|
||||||
accept_length_cpu = batch.spec_info.accept_length_cpu
|
accept_length_cpu = batch.spec_info.accept_length_cpu
|
||||||
batch.extend_lens = [x + 1 for x in accept_length_cpu]
|
batch.extend_lens = [x + 1 for x in accept_length_cpu]
|
||||||
|
batch.extend_num_tokens = sum(batch.extend_lens)
|
||||||
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
|
||||||
batch.req_pool_indices = batch.spec_info.req_pool_indices_for_draft_extend
|
|
||||||
seq_lens_cpu = batch.seq_lens.tolist()
|
seq_lens_cpu = batch.seq_lens.tolist()
|
||||||
|
assert len(batch.req_pool_indices) == len(batch.reqs)
|
||||||
|
|
||||||
pt = 0
|
pt = 0
|
||||||
i = 0
|
i = 0
|
||||||
for req in batch.reqs:
|
self.keep_indices = []
|
||||||
|
for idx, req in enumerate(batch.reqs):
|
||||||
if req.finished():
|
if req.finished():
|
||||||
continue
|
continue
|
||||||
|
self.keep_indices.append(idx)
|
||||||
# assert seq_len - pre_len == req.extend_input_len
|
# assert seq_len - pre_len == req.extend_input_len
|
||||||
input_len = batch.extend_lens[i]
|
input_len = batch.extend_lens[i]
|
||||||
seq_len = seq_lens_cpu[i]
|
seq_len = seq_lens_cpu[i]
|
||||||
batch.req_to_token_pool.req_to_token[req.req_pool_idx][
|
|
||||||
seq_len - input_len : seq_len
|
|
||||||
] = batch.out_cache_loc[pt : pt + input_len]
|
|
||||||
pt += input_len
|
pt += input_len
|
||||||
i += 1
|
i += 1
|
||||||
assert pt == batch.out_cache_loc.shape[0]
|
|
||||||
|
|
||||||
self.positions = torch.empty_like(self.verified_id)
|
self.positions = torch.empty_like(self.verified_id, dtype=torch.long)
|
||||||
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
|
new_verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
|
||||||
self.accept_length.add_(1)
|
self.accept_length.add_(1)
|
||||||
|
|
||||||
create_extend_spec_info[(self.accept_length.numel(),)](
|
create_extend_spec_info[(self.accept_length.numel(),)](
|
||||||
@@ -117,14 +107,22 @@ class EagleDraftInput:
|
|||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
paged_kernel_lens: torch.Tensor,
|
paged_kernel_lens: torch.Tensor,
|
||||||
|
paged_kernel_lens_sum: int,
|
||||||
req_to_token: torch.Tensor,
|
req_to_token: torch.Tensor,
|
||||||
):
|
):
|
||||||
bs = self.accept_length.numel()
|
bs = self.accept_length.numel()
|
||||||
|
keep_indices = torch.tensor(self.keep_indices, device=req_pool_indices.device)
|
||||||
|
req_pool_indices = req_pool_indices[keep_indices]
|
||||||
|
assert req_pool_indices.shape[0] == bs
|
||||||
|
assert req_pool_indices.shape[0] == paged_kernel_lens.shape[0]
|
||||||
|
|
||||||
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
qo_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||||
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
|
qo_indptr[1:] = torch.cumsum(self.accept_length, dim=0)
|
||||||
|
|
||||||
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
cum_kv_seq_len = torch.zeros((bs + 1,), dtype=torch.int32, device="cuda")
|
||||||
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
|
||||||
|
# TODO: replace cum_kv_seq_len[-1] with paged_kernel_lens_sum to avoid the device sync.
|
||||||
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
|
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
|
||||||
|
|
||||||
create_flashinfer_kv_indices_triton[(bs,)](
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
@@ -162,7 +160,21 @@ class EagleDraftInput:
|
|||||||
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
|
self.topk_index = torch.cat([self.topk_index, spec_info.topk_index])
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclass
|
||||||
|
class EagleVerifyOutput:
|
||||||
|
# Draft input batch
|
||||||
|
draft_input: EagleDraftInput
|
||||||
|
# Logit outputs from target worker
|
||||||
|
logits_output: LogitsProcessorOutput
|
||||||
|
# Accepted token ids including the bonus token
|
||||||
|
verified_id: torch.Tensor
|
||||||
|
# Accepted token length per sequence in a batch, on CPU.
|
||||||
|
accept_length_per_req_cpu: List[int]
|
||||||
|
# Accepted indices from logits_output.next_token_logits
|
||||||
|
accepeted_indices_cpu: List[int]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
class EagleVerifyInput:
|
class EagleVerifyInput:
|
||||||
draft_token: torch.Tensor
|
draft_token: torch.Tensor
|
||||||
custom_mask: torch.Tensor
|
custom_mask: torch.Tensor
|
||||||
@@ -267,6 +279,7 @@ class EagleVerifyInput:
|
|||||||
self,
|
self,
|
||||||
req_pool_indices: torch.Tensor,
|
req_pool_indices: torch.Tensor,
|
||||||
paged_kernel_lens: torch.Tensor,
|
paged_kernel_lens: torch.Tensor,
|
||||||
|
paged_kernel_lens_sum: int,
|
||||||
req_to_token: torch.Tensor,
|
req_to_token: torch.Tensor,
|
||||||
):
|
):
|
||||||
batch_size = len(req_pool_indices)
|
batch_size = len(req_pool_indices)
|
||||||
@@ -285,7 +298,11 @@ class EagleVerifyInput:
|
|||||||
paged_kernel_lens = paged_kernel_lens + self.draft_token_num
|
paged_kernel_lens = paged_kernel_lens + self.draft_token_num
|
||||||
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
cum_kv_seq_len[1:] = torch.cumsum(paged_kernel_lens, dim=0)
|
||||||
|
|
||||||
kv_indices = torch.empty(cum_kv_seq_len[-1], dtype=torch.int32, device="cuda")
|
kv_indices = torch.empty(
|
||||||
|
paged_kernel_lens_sum + self.draft_token_num * batch_size,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device="cuda",
|
||||||
|
)
|
||||||
|
|
||||||
create_flashinfer_kv_indices_triton[(batch_size,)](
|
create_flashinfer_kv_indices_triton[(batch_size,)](
|
||||||
req_to_token,
|
req_to_token,
|
||||||
@@ -298,7 +315,21 @@ class EagleVerifyInput:
|
|||||||
)
|
)
|
||||||
return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
|
return kv_indices, cum_kv_seq_len, qo_indptr, self.custom_mask
|
||||||
|
|
||||||
def verify(self, batch: ScheduleBatch, logits_output: torch.Tensor) -> torch.Tensor:
|
def verify(
|
||||||
|
self,
|
||||||
|
batch: ScheduleBatch,
|
||||||
|
logits_output: torch.Tensor,
|
||||||
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""WARNING: This API in-place modifies the states of logits_output
|
||||||
|
|
||||||
|
Verify and find accepted tokens based on logits output and batch
|
||||||
|
(which contains spec decoding information).
|
||||||
|
|
||||||
|
This API updates values inside logits_output based on the accepted
|
||||||
|
tokens. I.e., logits_output.next_token_logits only contains
|
||||||
|
accepted token logits.
|
||||||
|
"""
|
||||||
draft_token = torch.cat(
|
draft_token = torch.cat(
|
||||||
[self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
|
[self.draft_token, torch.full([1], -1, dtype=torch.int32, device="cuda")],
|
||||||
dim=-1,
|
dim=-1,
|
||||||
@@ -367,7 +398,6 @@ class EagleVerifyInput:
|
|||||||
|
|
||||||
new_accept_index = []
|
new_accept_index = []
|
||||||
unfinished_index = []
|
unfinished_index = []
|
||||||
finished_extend_len = {} # {rid:accept_length + 1}
|
|
||||||
accept_index_cpu = accept_index.tolist()
|
accept_index_cpu = accept_index.tolist()
|
||||||
predict_cpu = predict.tolist()
|
predict_cpu = predict.tolist()
|
||||||
has_finished = False
|
has_finished = False
|
||||||
@@ -382,7 +412,6 @@ class EagleVerifyInput:
|
|||||||
id = predict_cpu[idx]
|
id = predict_cpu[idx]
|
||||||
# if not found_finished:
|
# if not found_finished:
|
||||||
req.output_ids.append(id)
|
req.output_ids.append(id)
|
||||||
finished_extend_len[req.rid] = j + 1
|
|
||||||
req.check_finished()
|
req.check_finished()
|
||||||
if req.finished():
|
if req.finished():
|
||||||
has_finished = True
|
has_finished = True
|
||||||
@@ -400,11 +429,10 @@ class EagleVerifyInput:
|
|||||||
accept_index = accept_index[accept_index != -1]
|
accept_index = accept_index[accept_index != -1]
|
||||||
accept_length_cpu = accept_length.tolist()
|
accept_length_cpu = accept_length.tolist()
|
||||||
verified_id = predict[accept_index]
|
verified_id = predict[accept_index]
|
||||||
|
|
||||||
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
|
||||||
evict_mask[accept_index] = False
|
evict_mask[accept_index] = False
|
||||||
mem_need_free_idx = batch.out_cache_loc[evict_mask]
|
mem_need_free_idx = batch.out_cache_loc[evict_mask]
|
||||||
batch.token_to_kv_pool.free(mem_need_free_idx)
|
token_to_kv_pool_allocator.free(mem_need_free_idx)
|
||||||
assign_req_to_token_pool[(bs,)](
|
assign_req_to_token_pool[(bs,)](
|
||||||
batch.req_pool_indices,
|
batch.req_pool_indices,
|
||||||
batch.req_to_token_pool.req_to_token,
|
batch.req_to_token_pool.req_to_token,
|
||||||
@@ -427,20 +455,16 @@ class EagleVerifyInput:
|
|||||||
]
|
]
|
||||||
if has_finished:
|
if has_finished:
|
||||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
|
draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
|
||||||
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices[
|
|
||||||
unfinished_index
|
|
||||||
]
|
|
||||||
else:
|
else:
|
||||||
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
draft_input.seq_lens_for_draft_extend = batch.seq_lens
|
||||||
draft_input.req_pool_indices_for_draft_extend = batch.req_pool_indices
|
batch.out_cache_loc = batch.out_cache_loc[new_accept_index]
|
||||||
|
|
||||||
logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
|
return EagleVerifyOutput(
|
||||||
return (
|
draft_input=draft_input,
|
||||||
draft_input,
|
logits_output=logits_output,
|
||||||
logits_output,
|
verified_id=verified_id,
|
||||||
verified_id,
|
accept_length_per_req_cpu=accept_length_cpu,
|
||||||
finished_extend_len,
|
accepeted_indices_cpu=accept_index,
|
||||||
accept_length_cpu,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -456,6 +480,18 @@ def eagle_verify_retrive(
|
|||||||
draft_token_num: tl.constexpr,
|
draft_token_num: tl.constexpr,
|
||||||
max_len_upper: tl.constexpr,
|
max_len_upper: tl.constexpr,
|
||||||
):
|
):
|
||||||
|
"""
|
||||||
|
Args:
|
||||||
|
retrive_index: Pointer to indices of draft tokens
|
||||||
|
accept_mask: Mask indicating which tokens were accepted
|
||||||
|
retrive_cum_len: Cumulative lengths of token sequences in a batch
|
||||||
|
accept_index (out): Accept token indices
|
||||||
|
accept_length (out): Length of accepted tokens per sequence in a batch
|
||||||
|
extract_index (out): Index for last accepted tokens
|
||||||
|
max_len: Maximum length in a batch
|
||||||
|
draft_token_num: Number of tokens speculatively generated
|
||||||
|
max_len_upper: An upper bound for token sequence length
|
||||||
|
"""
|
||||||
pid = tl.program_id(axis=0)
|
pid = tl.program_id(axis=0)
|
||||||
|
|
||||||
retrive_end = tl.load(retrive_cum_len + pid + 1)
|
retrive_end = tl.load(retrive_cum_len + pid + 1)
|
||||||
@@ -649,7 +685,7 @@ def generate_draft_decode_kv_indices(
|
|||||||
tl.store(kv_indptr + zid, base + zid * iters)
|
tl.store(kv_indptr + zid, base + zid * iters)
|
||||||
|
|
||||||
|
|
||||||
@torch.compile
|
@torch.compile(dynamic=True)
|
||||||
def select_top_k_tokens(
|
def select_top_k_tokens(
|
||||||
i: int,
|
i: int,
|
||||||
topk_p: torch.Tensor,
|
topk_p: torch.Tensor,
|
||||||
@@ -671,13 +707,11 @@ def select_top_k_tokens(
|
|||||||
.unsqueeze(0)
|
.unsqueeze(0)
|
||||||
.repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
|
.repeat(topk_p.shape[0], 1), # shape: (b, topk + 1)
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# The later decode steps
|
# The later decode steps
|
||||||
expand_scores = torch.mul(
|
expand_scores = torch.mul(
|
||||||
scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
|
scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
|
||||||
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
|
) # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
|
||||||
|
|
||||||
topk_cs_p, topk_cs_index = fast_topk(
|
topk_cs_p, topk_cs_index = fast_topk(
|
||||||
expand_scores.flatten(start_dim=1), topk, dim=-1
|
expand_scores.flatten(start_dim=1), topk, dim=-1
|
||||||
) # (b, topk)
|
) # (b, topk)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from typing import List, Optional, Union
|
from typing import Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
@@ -22,11 +22,13 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
|
|||||||
from sglang.srt.speculative.eagle_utils import (
|
from sglang.srt.speculative.eagle_utils import (
|
||||||
EagleDraftInput,
|
EagleDraftInput,
|
||||||
EagleVerifyInput,
|
EagleVerifyInput,
|
||||||
|
EagleVerifyOutput,
|
||||||
assign_draft_cache_locs,
|
assign_draft_cache_locs,
|
||||||
fast_topk,
|
fast_topk,
|
||||||
select_top_k_tokens,
|
select_top_k_tokens,
|
||||||
)
|
)
|
||||||
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
|
||||||
|
from sglang.srt.utils import get_available_gpu_memory
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -42,12 +44,16 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
nccl_port: int,
|
nccl_port: int,
|
||||||
target_worker: TpModelWorker,
|
target_worker: TpModelWorker,
|
||||||
):
|
):
|
||||||
|
# Override context length with target model's context length
|
||||||
|
server_args.context_length = target_worker.model_runner.model_config.context_len
|
||||||
|
os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
|
||||||
|
|
||||||
# Do not capture cuda graph in `super().__init__()`
|
# Do not capture cuda graph in `super().__init__()`
|
||||||
# We will capture it later
|
# We will capture it later
|
||||||
backup_disable_cuda_graph = server_args.disable_cuda_graph
|
backup_disable_cuda_graph = server_args.disable_cuda_graph
|
||||||
server_args.disable_cuda_graph = True
|
server_args.disable_cuda_graph = True
|
||||||
|
|
||||||
# Load hot token ids
|
# Lossy optimization by using hot tokens
|
||||||
if server_args.speculative_token_map is not None:
|
if server_args.speculative_token_map is not None:
|
||||||
self.hot_token_id = load_token_map(server_args.speculative_token_map)
|
self.hot_token_id = load_token_map(server_args.speculative_token_map)
|
||||||
server_args.json_model_override_args = (
|
server_args.json_model_override_args = (
|
||||||
@@ -56,6 +62,12 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
else:
|
else:
|
||||||
self.hot_token_id = None
|
self.hot_token_id = None
|
||||||
|
|
||||||
|
# We share the allocator with a target worker. Draft/target worker
|
||||||
|
# owns its own KV cache.
|
||||||
|
self.req_to_token_pool, self.token_to_kv_pool_allocator = (
|
||||||
|
target_worker.get_memory_pool()
|
||||||
|
)
|
||||||
|
|
||||||
# Init target worker
|
# Init draft worker
|
||||||
super().__init__(
|
super().__init__(
|
||||||
gpu_id=gpu_id,
|
gpu_id=gpu_id,
|
||||||
@@ -64,9 +76,10 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
nccl_port=nccl_port,
|
nccl_port=nccl_port,
|
||||||
dp_rank=dp_rank,
|
dp_rank=dp_rank,
|
||||||
is_draft_worker=True,
|
is_draft_worker=True,
|
||||||
|
req_to_token_pool=self.req_to_token_pool,
|
||||||
|
token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
|
||||||
)
|
)
|
||||||
self.target_worker = target_worker
|
self.target_worker = target_worker
|
||||||
self.finish_extend_len = []
|
|
||||||
|
|
||||||
# Parse arguments
|
# Parse arguments
|
||||||
self.topk = server_args.speculative_eagle_topk
|
self.topk = server_args.speculative_eagle_topk
|
||||||
@@ -75,6 +88,9 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
server_args.speculative_algorithm
|
server_args.speculative_algorithm
|
||||||
)
|
)
|
||||||
self.server_args = server_args
|
self.server_args = server_args
|
||||||
|
self.use_nan_detection = self.server_args.enable_nan_detection
|
||||||
|
self.device = self.model_runner.device
|
||||||
|
self.gpu_id = self.model_runner.gpu_id
|
||||||
|
|
||||||
# Share the embedding and lm_head
|
# Share the embedding and lm_head
|
||||||
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
embed, head = self.target_worker.model_runner.model.get_embed_and_head()
|
||||||
@@ -82,8 +98,10 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
head = head.clone()
|
head = head.clone()
|
||||||
self.hot_token_id = self.hot_token_id.to(head.device)
|
self.hot_token_id = self.hot_token_id.to(head.device)
|
||||||
head.data = head.data[self.hot_token_id]
|
head.data = head.data[self.hot_token_id]
|
||||||
self.model_runner.model.set_embed_and_head(embed, head)
|
self.draft_model_runner.model.set_embed_and_head(embed, head)
|
||||||
self.model_runner.server_args.disable_cuda_graph = backup_disable_cuda_graph
|
self.draft_model_runner.server_args.disable_cuda_graph = (
|
||||||
|
backup_disable_cuda_graph
|
||||||
|
)
|
||||||
|
|
||||||
# Create multi-step attn backends and cuda graph runners
|
# Create multi-step attn backends and cuda graph runners
|
||||||
if server_args.attention_backend == "flashinfer":
|
if server_args.attention_backend == "flashinfer":
|
||||||
@@ -111,7 +129,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
|
f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
|
||||||
)
|
)
|
||||||
|
|
||||||
self.model_runner.draft_attn_backend = self.draft_attn_backend
|
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
|
||||||
self.init_cuda_graphs()
|
self.init_cuda_graphs()
|
||||||
|
|
||||||
def init_cuda_graphs(self):
|
def init_cuda_graphs(self):
|
||||||
@@ -122,55 +140,81 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
return
|
return
|
||||||
|
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
logger.info("Capture cuda graph begin. This can take up to several minutes.")
|
logger.info(
|
||||||
|
f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||||
|
)
|
||||||
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
|
self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
|
||||||
logger.info(f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s")
|
logger.info(
|
||||||
|
f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
|
||||||
|
)
|
||||||
|
|
||||||
def forward_batch_speculative_generation(self, batch: ScheduleBatch):
|
@property
|
||||||
|
def draft_model_runner(self):
|
||||||
|
return self.model_runner
|
||||||
|
|
||||||
|
def forward_batch_speculative_generation(
|
||||||
|
self, batch: ScheduleBatch
|
||||||
|
) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
|
||||||
|
"""Run speculative decoding forward.
|
||||||
|
|
||||||
|
NOTE: Many states of the batch are modified as you go through. It is not guaranteed
|
||||||
|
that the final output batch has the same state as the input.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: The batch to run forward. The state of the batch is modified as it runs.
|
||||||
|
Returns:
|
||||||
|
A tuple of the final logit output of the target model, next tokens accepted,
|
||||||
|
the batch id (used for overlap schedule), and number of accepted tokens.
|
||||||
|
"""
|
||||||
|
assert not batch.spec_algorithm.is_none()
|
||||||
if batch.forward_mode.is_decode():
|
if batch.forward_mode.is_decode():
|
||||||
# Draft
|
spec_info, to_free_cache_loc = self.draft(batch)
|
||||||
spec_info: EagleVerifyInput = self.draft(batch)
|
logits_output, verify_output, model_worker_batch = self.verify(
|
||||||
|
batch, spec_info
|
||||||
# Verify
|
)
|
||||||
(
|
# Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
|
||||||
next_draft_input,
|
self.token_to_kv_pool_allocator.free(to_free_cache_loc)
|
||||||
logits_output,
|
# If it is None, it means all requests are finished.
|
||||||
verified_id,
|
|
||||||
self.finish_extend_len,
|
|
||||||
accept_length_cpu,
|
|
||||||
model_worker_batch,
|
|
||||||
) = self.verify(batch, spec_info)
|
|
||||||
batch.spec_info = next_draft_input
|
|
||||||
# if it is None, means all requsets are finished
|
|
||||||
if batch.spec_info.verified_id is not None:
|
if batch.spec_info.verified_id is not None:
|
||||||
self.forward_draft_extend_after_decode(batch)
|
self.forward_draft_extend_after_decode(batch)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
logits_output,
|
logits_output,
|
||||||
verified_id,
|
verify_output.verified_id,
|
||||||
model_worker_batch,
|
model_worker_batch.bid,
|
||||||
sum(accept_length_cpu),
|
sum(verify_output.accept_length_per_req_cpu),
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Forward with the target model and get hidden states.
|
logits_output, next_token_ids, bid = self.forward_target_extend(batch)
|
||||||
# We need the full hidden states to prefill the KV cache of the draft model.
|
self.forward_draft_extend(
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
batch, logits_output.hidden_states, next_token_ids
|
||||||
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
|
||||||
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
|
|
||||||
model_worker_batch
|
|
||||||
)
|
)
|
||||||
|
return logits_output, next_token_ids, bid, 0
|
||||||
|
|
||||||
# Forward with the draft model.
|
def forward_target_extend(
|
||||||
batch.spec_info = EagleDraftInput(
|
self, batch: ScheduleBatch
|
||||||
hidden_states=logits_output.hidden_states,
|
) -> Tuple[LogitsProcessorOutput, List[int], int]:
|
||||||
verified_id=next_token_ids,
|
"""Run the target extend.
|
||||||
)
|
|
||||||
self.forward_draft_extend(batch)
|
Args:
|
||||||
return logits_output, next_token_ids, model_worker_batch, 0
|
batch: The batch to run. States could be modified.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
logits_output: The output of logits. It will contain the full hidden states.
|
||||||
|
next_token_ids: Next token ids generated.
|
||||||
|
bid: The model batch ID. Used for overlap schedule.
|
||||||
|
"""
|
||||||
|
# Forward with the target model and get hidden states.
|
||||||
|
# We need the full hidden states to prefill the KV cache of the draft model.
|
||||||
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
|
model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
|
||||||
|
logits_output, next_token_ids = self.target_worker.forward_batch_generation(
|
||||||
|
model_worker_batch
|
||||||
|
)
|
||||||
|
return logits_output, next_token_ids, model_worker_batch.bid
|
||||||
|
|
||||||
def draft(self, batch: ScheduleBatch):
|
def draft(self, batch: ScheduleBatch):
|
||||||
self._set_mem_pool(batch, self.model_runner)
|
|
||||||
|
|
||||||
# Parse args
|
# Parse args
|
||||||
num_seqs = batch.batch_size()
|
num_seqs = batch.batch_size()
|
||||||
spec_info = batch.spec_info
|
spec_info = batch.spec_info
|
||||||
@@ -188,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.topk,
|
self.topk,
|
||||||
self.speculative_num_steps,
|
self.speculative_num_steps,
|
||||||
)
|
)
|
||||||
|
|
||||||
batch.out_cache_loc = out_cache_loc
|
batch.out_cache_loc = out_cache_loc
|
||||||
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
|
||||||
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
|
spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
|
||||||
@@ -196,11 +239,12 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
# Get forward batch
|
# Get forward batch
|
||||||
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(
|
||||||
|
model_worker_batch, self.draft_model_runner
|
||||||
|
)
|
||||||
can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
|
can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
|
||||||
forward_batch
|
forward_batch
|
||||||
)
|
)
|
||||||
|
|
||||||
if can_cuda_graph:
|
if can_cuda_graph:
|
||||||
score_list, token_list, parents_list = self.cuda_graph_runner.replay(
|
score_list, token_list, parents_list = self.cuda_graph_runner.replay(
|
||||||
forward_batch
|
forward_batch
|
||||||
@@ -208,7 +252,9 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
else:
|
else:
|
||||||
# Initialize attention backend
|
# Initialize attention backend
|
||||||
self.draft_attn_backend.init_forward_metadata(forward_batch)
|
self.draft_attn_backend.init_forward_metadata(forward_batch)
|
||||||
|
forward_batch = ForwardBatch.init_new(
|
||||||
|
model_worker_batch, self.draft_model_runner
|
||||||
|
)
|
||||||
# Run forward steps
|
# Run forward steps
|
||||||
score_list, token_list, parents_list = self.draft_forward(forward_batch)
|
score_list, token_list, parents_list = self.draft_forward(forward_batch)
|
||||||
|
|
||||||
@@ -225,10 +271,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
batch.sampling_info.is_all_greedy,
|
batch.sampling_info.is_all_greedy,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Free cache locations
|
return ret, out_cache_loc
|
||||||
batch.token_to_kv_pool.free(out_cache_loc)
|
|
||||||
self._set_mem_pool(batch, self.target_worker.model_runner)
|
|
||||||
return ret
|
|
||||||
|
|
||||||
def draft_forward(self, forward_batch: ForwardBatch):
|
def draft_forward(self, forward_batch: ForwardBatch):
|
||||||
# Parse args
|
# Parse args
|
||||||
@@ -278,6 +321,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
logits_output = self.model_runner.model.forward(
|
logits_output = self.model_runner.model.forward(
|
||||||
forward_batch.input_ids, forward_batch.positions, forward_batch
|
forward_batch.input_ids, forward_batch.positions, forward_batch
|
||||||
)
|
)
|
||||||
|
self._detect_nan_if_needed(logits_output)
|
||||||
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||||
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
|
topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||||
if self.hot_token_id is not None:
|
if self.hot_token_id is not None:
|
||||||
@@ -294,71 +338,88 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
logits_output, _ = self.target_worker.forward_batch_generation(
|
logits_output, _ = self.target_worker.forward_batch_generation(
|
||||||
model_worker_batch, skip_sample=True
|
model_worker_batch, skip_sample=True
|
||||||
)
|
)
|
||||||
|
self._detect_nan_if_needed(logits_output)
|
||||||
spec_info.hidden_states = logits_output.hidden_states
|
spec_info.hidden_states = logits_output.hidden_states
|
||||||
res = spec_info.verify(batch, logits_output)
|
res: EagleVerifyOutput = spec_info.verify(
|
||||||
batch.forward_mode = ForwardMode.DECODE
|
batch, logits_output, self.token_to_kv_pool_allocator
|
||||||
return res + (model_worker_batch,)
|
)
|
||||||
|
|
||||||
def forward_draft_extend(self, batch: ScheduleBatch):
|
# Post process based on verified outputs.
|
||||||
self._set_mem_pool(batch, self.model_runner)
|
# Pick the indices that we care about (accepted)
|
||||||
|
logits_output.next_token_logits = logits_output.next_token_logits[
|
||||||
|
res.accepeted_indices_cpu
|
||||||
|
]
|
||||||
|
logits_output.hidden_states = logits_output.hidden_states[
|
||||||
|
res.accepeted_indices_cpu
|
||||||
|
]
|
||||||
|
# Prepare the batch for the next draft forwards.
|
||||||
|
batch.forward_mode = ForwardMode.DECODE
|
||||||
|
batch.spec_info = res.draft_input
|
||||||
|
|
||||||
|
return logits_output, res, model_worker_batch
|
||||||
|
|
||||||
|
def forward_draft_extend(
|
||||||
|
self,
|
||||||
|
batch: ScheduleBatch,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
next_token_ids: List[int],
|
||||||
|
):
|
||||||
|
"""Run draft model extend. This API modifies the states of the batch.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
batch: The batch to run.
|
||||||
|
hidden_states: Hidden states from the target model forward
|
||||||
|
next_token_ids: Next token ids generated from the target forward.
|
||||||
|
"""
|
||||||
|
batch.spec_info = EagleDraftInput(
|
||||||
|
hidden_states=hidden_states,
|
||||||
|
verified_id=next_token_ids,
|
||||||
|
)
|
||||||
batch.spec_info.prepare_for_extend(batch)
|
batch.spec_info.prepare_for_extend(batch)
|
||||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
model_worker_batch, self.draft_model_runner
|
||||||
self.capture_for_decode(logits_output, forward_batch)
|
)
|
||||||
self._set_mem_pool(batch, self.target_worker.model_runner)
|
logits_output = self.draft_model_runner.forward(forward_batch)
|
||||||
|
self._detect_nan_if_needed(logits_output)
|
||||||
def _set_mem_pool(self, batch: ScheduleBatch, runner: ModelRunner):
|
assert isinstance(forward_batch.spec_info, EagleDraftInput)
|
||||||
batch.token_to_kv_pool = runner.token_to_kv_pool
|
assert forward_batch.spec_info is batch.spec_info
|
||||||
batch.req_to_token_pool = runner.req_to_token_pool
|
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||||
|
|
||||||
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
|
||||||
seq_lens_backup = batch.seq_lens
|
seq_lens_backup = batch.seq_lens
|
||||||
req_pool_indices_backup = batch.req_pool_indices
|
|
||||||
|
|
||||||
self._set_mem_pool(batch, self.model_runner)
|
|
||||||
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
batch.forward_mode = ForwardMode.DRAFT_EXTEND
|
||||||
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
|
batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
|
||||||
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
|
||||||
|
# We don't need logprob for this extend.
|
||||||
model_worker_batch = batch.get_model_worker_batch()
|
model_worker_batch = batch.get_model_worker_batch()
|
||||||
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
forward_batch = ForwardBatch.init_new(
|
||||||
logits_output = self.model_runner.forward(forward_batch)
|
model_worker_batch, self.draft_model_runner
|
||||||
self.capture_for_decode(logits_output, forward_batch)
|
)
|
||||||
self._set_mem_pool(batch, self.target_worker.model_runner)
|
logits_output = self.draft_model_runner.forward(forward_batch)
|
||||||
|
self._detect_nan_if_needed(logits_output)
|
||||||
|
assert forward_batch.spec_info is batch.spec_info
|
||||||
|
self.capture_for_decode(logits_output, forward_batch.spec_info)
|
||||||
|
|
||||||
# Restore backup.
|
# Restore backup.
|
||||||
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
# This is because `seq_lens` can be modified in `prepare_extend_after_decode`
|
||||||
batch.forward_mode = ForwardMode.DECODE
|
batch.forward_mode = ForwardMode.DECODE
|
||||||
batch.seq_lens = seq_lens_backup
|
batch.seq_lens = seq_lens_backup
|
||||||
batch.req_pool_indices = req_pool_indices_backup
|
|
||||||
|
|
||||||
def capture_for_decode(
|
def capture_for_decode(
|
||||||
self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
|
self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
|
||||||
):
|
):
|
||||||
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
probs = torch.softmax(logits_output.next_token_logits, dim=-1)
|
||||||
spec_info = forward_batch.spec_info
|
draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
|
||||||
spec_info.topk_p, spec_info.topk_index = fast_topk(probs, self.topk, dim=-1)
|
draft_input.hidden_states = logits_output.hidden_states
|
||||||
spec_info.hidden_states = logits_output.hidden_states
|
|
||||||
|
|
||||||
# Don't support prefix share now.
|
def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
|
||||||
def finish_request(self, reqs: Union[Req, List[Req]]):
|
if self.use_nan_detection:
|
||||||
if not isinstance(reqs, List):
|
logits = logits_output.next_token_logits
|
||||||
reqs = [reqs]
|
if torch.any(torch.isnan(logits)):
|
||||||
for req in reqs:
|
logger.warning("Detected errors during sampling! NaN in the logits.")
|
||||||
if req.rid not in self.finish_extend_len:
|
raise ValueError("Detected errors during sampling! NaN in the logits.")
|
||||||
continue
|
|
||||||
req_len = (
|
|
||||||
len(req.origin_input_ids)
|
|
||||||
+ len(req.output_ids)
|
|
||||||
- self.finish_extend_len[req.rid]
|
|
||||||
- 1
|
|
||||||
)
|
|
||||||
kv_indices = self.model_runner.req_to_token_pool.req_to_token[
|
|
||||||
req.req_pool_idx
|
|
||||||
][:req_len]
|
|
||||||
self.model_runner.token_to_kv_pool.free(kv_indices)
|
|
||||||
self.model_runner.req_to_token_pool.free(req.req_pool_idx)
|
|
||||||
|
|
||||||
|
|
||||||
def load_token_map(token_map_path: str) -> List[int]:
|
def load_token_map(token_map_path: str) -> List[int]:
|
||||||
|
|||||||
@@ -20,7 +20,3 @@ class SpeculativeAlgorithm(IntEnum):
|
|||||||
if name is not None:
|
if name is not None:
|
||||||
name = name.upper()
|
name = name.upper()
|
||||||
return name_map[name]
|
return name_map[name]
|
||||||
|
|
||||||
|
|
||||||
class SpecInfo:
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -1,3 +1,4 @@
|
|||||||
|
import multiprocessing as mp
|
||||||
import random
|
import random
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
@@ -18,6 +19,8 @@ from sglang.test.test_utils import (
|
|||||||
popen_launch_server,
|
popen_launch_server,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
acc_rate_tolerance = 0.15
|
||||||
|
|
||||||
|
|
||||||
class TestEAGLEEngine(unittest.TestCase):
|
class TestEAGLEEngine(unittest.TestCase):
|
||||||
BASE_CONFIG = {
|
BASE_CONFIG = {
|
||||||
@@ -43,13 +46,19 @@ class TestEAGLEEngine(unittest.TestCase):
|
|||||||
configs = [
|
configs = [
|
||||||
self.BASE_CONFIG,
|
self.BASE_CONFIG,
|
||||||
{**self.BASE_CONFIG, "disable_cuda_graph": True},
|
{**self.BASE_CONFIG, "disable_cuda_graph": True},
|
||||||
|
{**self.BASE_CONFIG, "chunked_prefill_size": 2},
|
||||||
]
|
]
|
||||||
|
|
||||||
for config in configs:
|
for config in configs:
|
||||||
with self.subTest(
|
with self.subTest(
|
||||||
cuda_graph=(
|
cuda_graph=(
|
||||||
"enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
|
"enabled" if len(config) == len(self.BASE_CONFIG) else "disabled"
|
||||||
)
|
),
|
||||||
|
chunked_prefill_size=(
|
||||||
|
config["chunked_prefill_size"]
|
||||||
|
if "chunked_prefill_size" in config
|
||||||
|
else "default"
|
||||||
|
),
|
||||||
):
|
):
|
||||||
engine = sgl.Engine(**config)
|
engine = sgl.Engine(**config)
|
||||||
try:
|
try:
|
||||||
@@ -125,6 +134,8 @@ class TestEAGLEServer(unittest.TestCase):
|
|||||||
"64",
|
"64",
|
||||||
"--mem-fraction-static",
|
"--mem-fraction-static",
|
||||||
"0.7",
|
"0.7",
|
||||||
|
"--chunked-prefill-size",
|
||||||
|
"128",
|
||||||
"--cuda-graph-max-bs",
|
"--cuda-graph-max-bs",
|
||||||
"32",
|
"32",
|
||||||
],
|
],
|
||||||
@@ -196,6 +207,137 @@ class TestEAGLEServer(unittest.TestCase):
|
|||||||
self.assertGreater(metrics["accuracy"], 0.20)
|
self.assertGreater(metrics["accuracy"], 0.20)
|
||||||
|
|
||||||
|
|
||||||
|
def measure_acc_rate(engine):
|
||||||
|
tic = time.time()
|
||||||
|
prompt = [
|
||||||
|
"Human: Give me a fully functional FastAPI server. Show the python code.<|separator|>\n\nAssistant:"
|
||||||
|
]
|
||||||
|
sampling_params = {"temperature": 0, "max_new_tokens": 512}
|
||||||
|
output = engine.generate(prompt, sampling_params)
|
||||||
|
output = output[0]
|
||||||
|
latency = time.time() - tic
|
||||||
|
|
||||||
|
if "spec_verify_ct" in output["meta_info"]:
|
||||||
|
base_acc_length = (
|
||||||
|
output["meta_info"]["completion_tokens"]
|
||||||
|
/ output["meta_info"]["spec_verify_ct"]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
base_acc_length = 0.0
|
||||||
|
|
||||||
|
base_speed = output["meta_info"]["completion_tokens"] / latency
|
||||||
|
return base_acc_length, base_speed
|
||||||
|
|
||||||
|
|
||||||
|
class TestEagleAcceptanceRate(unittest.TestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
mp.set_start_method("spawn", force=True)
|
||||||
|
ref_engine = sgl.Engine(
|
||||||
|
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||||
|
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||||
|
speculative_algorithm="EAGLE",
|
||||||
|
speculative_num_steps=5,
|
||||||
|
speculative_eagle_topk=8,
|
||||||
|
speculative_num_draft_tokens=64,
|
||||||
|
mem_fraction_static=0.7,
|
||||||
|
disable_radix_cache=True,
|
||||||
|
)
|
||||||
|
cls.base_acc_length, cls.base_speed = measure_acc_rate(ref_engine)
|
||||||
|
ref_engine.shutdown()
|
||||||
|
assert cls.base_acc_length > 4.45
|
||||||
|
|
||||||
|
def test_acc_rate(self):
|
||||||
|
base_acc_length, base_speed = self.base_acc_length, self.base_speed
|
||||||
|
chunk_engine = sgl.Engine(
|
||||||
|
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||||
|
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||||
|
speculative_algorithm="EAGLE",
|
||||||
|
speculative_num_steps=5,
|
||||||
|
speculative_eagle_topk=8,
|
||||||
|
speculative_num_draft_tokens=64,
|
||||||
|
mem_fraction_static=0.7,
|
||||||
|
chunked_prefill_size=2,
|
||||||
|
disable_radix_cache=True,
|
||||||
|
)
|
||||||
|
chunked_acc_length, chunked_base_speed = measure_acc_rate(chunk_engine)
|
||||||
|
chunk_engine.shutdown()
|
||||||
|
print(base_acc_length, base_speed)
|
||||||
|
print(chunked_acc_length, chunked_base_speed)
|
||||||
|
assert abs(base_acc_length - chunked_acc_length) < acc_rate_tolerance
|
||||||
|
|
||||||
|
def test_acc_rate_prefix_caching(self):
|
||||||
|
base_acc_length, base_speed = self.base_acc_length, self.base_speed
|
||||||
|
prefix_caching_engine = sgl.Engine(
|
||||||
|
model_path=DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||||
|
speculative_draft_model_path=DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||||
|
speculative_algorithm="EAGLE",
|
||||||
|
speculative_num_steps=5,
|
||||||
|
speculative_eagle_topk=8,
|
||||||
|
speculative_num_draft_tokens=64,
|
||||||
|
mem_fraction_static=0.7,
|
||||||
|
chunked_prefill_size=4,
|
||||||
|
schedule_policy="lpm",
|
||||||
|
)
|
||||||
|
for _ in range(10):
|
||||||
|
acc_length, _ = measure_acc_rate(prefix_caching_engine)
|
||||||
|
print(f"{acc_length=}")
|
||||||
|
assert abs(base_acc_length - acc_length) < acc_rate_tolerance
|
||||||
|
# The second one should hit the prefix cache.
|
||||||
|
prefix_caching_engine.shutdown()
|
||||||
|
|
||||||
|
|
||||||
|
class TestEAGLERetract(unittest.TestCase):
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||||
|
cls.process = popen_launch_server(
|
||||||
|
DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
|
||||||
|
cls.base_url,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
other_args=[
|
||||||
|
"--speculative-algorithm",
|
||||||
|
"EAGLE",
|
||||||
|
"--speculative-draft-model-path",
|
||||||
|
DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
|
||||||
|
"--speculative-num-steps",
|
||||||
|
"5",
|
||||||
|
"--speculative-eagle-topk",
|
||||||
|
"8",
|
||||||
|
"--speculative-num-draft-tokens",
|
||||||
|
"64",
|
||||||
|
"--mem-fraction-static",
|
||||||
|
"0.7",
|
||||||
|
"--chunked-prefill-size",
|
||||||
|
"128",
|
||||||
|
"--max-running-requests",
|
||||||
|
"64",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def tearDownClass(cls):
|
||||||
|
kill_process_tree(cls.process.pid)
|
||||||
|
|
||||||
|
def test_gsm8k(self):
|
||||||
|
args = SimpleNamespace(
|
||||||
|
num_shots=5,
|
||||||
|
data_path=None,
|
||||||
|
num_questions=200,
|
||||||
|
max_new_tokens=512,
|
||||||
|
parallel=128,
|
||||||
|
host="http://127.0.0.1",
|
||||||
|
port=int(self.base_url.split(":")[-1]),
|
||||||
|
)
|
||||||
|
metrics = run_eval(args)
|
||||||
|
print(f"{metrics=}")
|
||||||
|
|
||||||
|
self.assertGreater(metrics["accuracy"], 0.20)
|
||||||
|
# Wait a little bit so that the memory check happens.
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
|
||||||
class TestEAGLEServerTriton(TestEAGLEServer):
|
class TestEAGLEServerTriton(TestEAGLEServer):
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
|
|||||||
Reference in New Issue
Block a user