Support speculative decoding in hybrid attention backend (#9573)
This commit is contained in:
@@ -5,6 +5,7 @@ import torch
|
||||
from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||
from sglang.srt.model_executor.model_runner import ModelRunner
|
||||
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
||||
|
||||
|
||||
@@ -12,19 +13,27 @@ class HybridAttnBackend(AttentionBackend):
|
||||
"""Support different backends for prefill and decode."""
|
||||
|
||||
def __init__(
    self,
    model_runner: ModelRunner,
    prefill_backend: AttentionBackend,
    decode_backend: AttentionBackend,
):
    """Store the model runner and the two attention backends being combined.

    Args:
        model_runner: The owning model runner; kept on the instance as
            ``self.model_runner``.
        prefill_backend: Backend kept as ``self.prefill_backend``.
        decode_backend: Backend kept as ``self.decode_backend``.
    """
    self.model_runner = model_runner
    self.prefill_backend = prefill_backend
    self.decode_backend = decode_backend
|
||||
|
||||
def init_forward_metadata(self, forward_batch: ForwardBatch):
    """Initialize forward metadata on the backend matching the batch's mode.

    Batches whose ``forward_mode.is_decode_or_idle()`` is true are handed to
    the decode backend; every other forward mode goes to the prefill backend.

    Args:
        forward_batch: The batch whose ``forward_mode`` selects the backend
            and which is passed through to it unchanged.
    """
    if forward_batch.forward_mode.is_decode_or_idle():
        self.decode_backend.init_forward_metadata(forward_batch)
    else:
        self.prefill_backend.init_forward_metadata(forward_batch)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
    """Initialize CUDA graph state on the wrapped backend(s).

    The decode backend is always initialized. The prefill backend is
    initialized as well whenever a speculative algorithm is configured in
    ``self.model_runner.server_args``.

    Args:
        max_bs: Forwarded unchanged to each backend's
            ``init_cuda_graph_state``.
        max_num_tokens: Forwarded unchanged to each backend's
            ``init_cuda_graph_state``.
    """
    self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
    if self.model_runner.server_args.speculative_algorithm is not None:
        # When speculative decoding is enabled, we also need to initialize the
        # prefill backend's cuda graph state to support target_verify.
        self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
@@ -36,15 +45,26 @@ class HybridAttnBackend(AttentionBackend):
|
||||
forward_mode: ForwardMode,
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
):
|
||||
self.decode_backend.init_forward_metadata_capture_cuda_graph(
|
||||
bs,
|
||||
num_tokens,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
encoder_lens,
|
||||
forward_mode,
|
||||
spec_info,
|
||||
)
|
||||
if forward_mode.is_decode_or_idle():
|
||||
self.decode_backend.init_forward_metadata_capture_cuda_graph(
|
||||
bs,
|
||||
num_tokens,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
encoder_lens,
|
||||
forward_mode,
|
||||
spec_info,
|
||||
)
|
||||
else:
|
||||
self.prefill_backend.init_forward_metadata_capture_cuda_graph(
|
||||
bs,
|
||||
num_tokens,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
encoder_lens,
|
||||
forward_mode,
|
||||
spec_info,
|
||||
)
|
||||
|
||||
def init_forward_metadata_replay_cuda_graph(
|
||||
self,
|
||||
@@ -57,16 +77,28 @@ class HybridAttnBackend(AttentionBackend):
|
||||
spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
|
||||
seq_lens_cpu: Optional[torch.Tensor],
|
||||
):
|
||||
self.decode_backend.init_forward_metadata_replay_cuda_graph(
|
||||
bs,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
encoder_lens,
|
||||
forward_mode,
|
||||
spec_info,
|
||||
seq_lens_cpu,
|
||||
)
|
||||
if forward_mode.is_decode_or_idle():
|
||||
self.decode_backend.init_forward_metadata_replay_cuda_graph(
|
||||
bs,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
encoder_lens,
|
||||
forward_mode,
|
||||
spec_info,
|
||||
seq_lens_cpu,
|
||||
)
|
||||
else:
|
||||
self.prefill_backend.init_forward_metadata_replay_cuda_graph(
|
||||
bs,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
seq_lens_sum,
|
||||
encoder_lens,
|
||||
forward_mode,
|
||||
spec_info,
|
||||
seq_lens_cpu,
|
||||
)
|
||||
|
||||
def get_cuda_graph_seq_len_fill_value(self):
    """Return the decode backend's CUDA graph sequence-length fill value.

    The call is delegated to ``self.decode_backend`` and its result is
    returned as-is; the prefill backend is not consulted.
    """
    return self.decode_backend.get_cuda_graph_seq_len_fill_value()
|
||||
|
||||
@@ -1440,14 +1440,12 @@ class ModelRunner:
|
||||
else self.server_args.attention_backend
|
||||
)
|
||||
if self.decode_attention_backend_str != self.prefill_attention_backend_str:
|
||||
assert (
|
||||
self.server_args.speculative_algorithm is None
|
||||
), "Currently HybridAttentionBackend does not support speculative decoding."
|
||||
from sglang.srt.layers.attention.hybrid_attn_backend import (
|
||||
HybridAttnBackend,
|
||||
)
|
||||
|
||||
attn_backend = HybridAttnBackend(
|
||||
self,
|
||||
decode_backend=self._get_attention_backend_from_str(
|
||||
self.decode_attention_backend_str
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user