[Feature] Speculative decoding support lookahead (#9873)
Co-authored-by: a4zhangfei <a4zhangfei@qq.com> Co-authored-by: Qiaolin-Yu <liin1211@outlook.com>
This commit is contained in:
@@ -29,6 +29,7 @@ from sglang.srt.layers.radix_attention import AttentionType
|
||||
from sglang.srt.mem_cache.allocator import SWATokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
from sglang.srt.speculative.lookahead_utils import LookaheadVerifyInput
|
||||
from sglang.srt.utils import (
|
||||
is_flashinfer_available,
|
||||
is_sm100_supported,
|
||||
@@ -317,7 +318,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
seq_lens: torch.Tensor,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
decode_wrappers = []
|
||||
@@ -422,7 +425,9 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
seq_lens_sum: int,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
if forward_mode.is_decode_or_idle():
|
||||
@@ -638,7 +643,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
@@ -651,7 +658,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
decode_wrappers = decode_wrappers or self.decode_wrappers
|
||||
self.call_begin_forward(
|
||||
@@ -673,7 +682,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
assert self.sliding_window_size is not None
|
||||
for wrapper_id in range(2):
|
||||
@@ -721,7 +732,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
seq_lens_sum: int,
|
||||
decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -753,7 +766,9 @@ class FlashInferIndicesUpdaterDecode:
|
||||
paged_kernel_lens_sum: int,
|
||||
kv_indptr: torch.Tensor,
|
||||
kv_start_idx: torch.Tensor,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
use_sliding_window_kv_pool: bool = False,
|
||||
):
|
||||
@@ -858,7 +873,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
# Keep the signature for type checking. It will be assigned during runtime.
|
||||
raise NotImplementedError()
|
||||
@@ -873,7 +890,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
if use_ragged:
|
||||
# TODO: remove this device sync, we can use forward_batch.extend_prefix_lens_cpu
|
||||
@@ -909,7 +928,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -955,7 +976,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
prefill_wrappers: List[BatchPrefillWithPagedKVCacheWrapper],
|
||||
use_ragged: bool,
|
||||
encoder_lens: Optional[torch.Tensor],
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
):
|
||||
for wrapper_id in range(2):
|
||||
if wrapper_id == 0:
|
||||
@@ -997,7 +1020,9 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
kv_indptr: torch.Tensor,
|
||||
qo_indptr: torch.Tensor,
|
||||
use_ragged: bool,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
spec_info: Optional[
|
||||
Union[EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput]
|
||||
],
|
||||
use_sliding_window_kv_pool: bool = False,
|
||||
):
|
||||
bs = len(seq_lens)
|
||||
@@ -1024,8 +1049,8 @@ class FlashInferIndicesUpdaterPrefill:
|
||||
qo_indptr = qo_indptr[: bs + 1]
|
||||
custom_mask = None
|
||||
else:
|
||||
assert isinstance(spec_info, EagleDraftInput) or isinstance(
|
||||
spec_info, EagleVerifyInput
|
||||
assert isinstance(
|
||||
spec_info, (EagleDraftInput, EagleVerifyInput, LookaheadVerifyInput)
|
||||
)
|
||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||
spec_info.generate_attn_arg_prefill(
|
||||
|
||||
Reference in New Issue
Block a user