Support speculative decoding in hybrid attention backend (#9573)
@@ -5,6 +5,7 @@ import torch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
 
 
@@ -12,19 +13,27 @@ class HybridAttnBackend(AttentionBackend):
     """Support different backends for prefill and decode."""
 
     def __init__(
-        self, prefill_backend: AttentionBackend, decode_backend: AttentionBackend
+        self,
+        model_runner: ModelRunner,
+        prefill_backend: AttentionBackend,
+        decode_backend: AttentionBackend,
     ):
+        self.model_runner = model_runner
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             self.decode_backend.init_forward_metadata(forward_batch)
         else:
             self.prefill_backend.init_forward_metadata(forward_batch)
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
         self.decode_backend.init_cuda_graph_state(max_bs, max_num_tokens)
+        if self.model_runner.server_args.speculative_algorithm is not None:
+            # When speculative decoding is enabled, we also need to initialize the
+            # prefill backend's cuda graph state to support target_verify.
+            self.prefill_backend.init_cuda_graph_state(max_bs, max_num_tokens)
 
     def init_forward_metadata_capture_cuda_graph(
         self,
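
The dispatch rule above is the core of the change: decode and idle batches keep using the decode backend, while every other mode, including the speculative target_verify pass, is routed to the prefill backend. A minimal self-contained sketch of that rule (stub names only, not sglang code):

from enum import Enum, auto


class Mode(Enum):
    # Stub for sglang's ForwardMode; only the modes relevant here.
    EXTEND = auto()
    DECODE = auto()
    IDLE = auto()
    TARGET_VERIFY = auto()

    def is_decode_or_idle(self) -> bool:
        return self in (Mode.DECODE, Mode.IDLE)


def pick_backend(mode: Mode, decode_backend: str, prefill_backend: str) -> str:
    # Mirrors init_forward_metadata above: decode/idle -> decode backend,
    # everything else (extend, target_verify, ...) -> prefill backend.
    return decode_backend if mode.is_decode_or_idle() else prefill_backend


assert pick_backend(Mode.DECODE, "decode", "prefill") == "decode"
assert pick_backend(Mode.TARGET_VERIFY, "decode", "prefill") == "prefill"
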
@@ -36,15 +45,26 @@ class HybridAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
     ):
-        self.decode_backend.init_forward_metadata_capture_cuda_graph(
-            bs,
-            num_tokens,
-            req_pool_indices,
-            seq_lens,
-            encoder_lens,
-            forward_mode,
-            spec_info,
-        )
+        if forward_mode.is_decode_or_idle():
+            self.decode_backend.init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+        else:
+            self.prefill_backend.init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -57,16 +77,28 @@ class HybridAttnBackend(AttentionBackend):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        self.decode_backend.init_forward_metadata_replay_cuda_graph(
-            bs,
-            req_pool_indices,
-            seq_lens,
-            seq_lens_sum,
-            encoder_lens,
-            forward_mode,
-            spec_info,
-            seq_lens_cpu,
-        )
+        if forward_mode.is_decode_or_idle():
+            self.decode_backend.init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+        else:
+            self.prefill_backend.init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
 
     def get_cuda_graph_seq_len_fill_value(self):
         return self.decode_backend.get_cuda_graph_seq_len_fill_value()
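
The capture and replay hunks follow the same split, which is why init_cuda_graph_state earlier in this file now also prepares the prefill backend whenever a speculative algorithm is configured: target_verify batches are captured and replayed through the prefill backend, so its CUDA graph state must exist before capture. A runnable stub illustrating the failure the extra init prevents (illustrative names, not sglang APIs):

from dataclasses import dataclass, field
from typing import List


@dataclass
class StubBackend:
    # Illustrative stand-in for an attention backend with CUDA graph state.
    name: str
    graph_ready: bool = False
    captured: List[str] = field(default_factory=list)

    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int) -> None:
        self.graph_ready = True

    def capture(self, mode: str) -> None:
        # A real backend would record a CUDA graph here; without graph
        # state, this is exactly where capture would fail.
        assert self.graph_ready, f"{self.name}: cuda graph state not initialized"
        self.captured.append(mode)


decode, prefill = StubBackend("decode"), StubBackend("prefill")
speculative = True  # corresponds to server_args.speculative_algorithm is not None

decode.init_cuda_graph_state(max_bs=8, max_num_tokens=8)
if speculative:
    prefill.init_cuda_graph_state(max_bs=8, max_num_tokens=8)

decode.capture("decode")
prefill.capture("target_verify")  # would raise without the speculative init
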
@@ -1440,14 +1440,12 @@ class ModelRunner:
             else self.server_args.attention_backend
         )
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
-            assert (
-                self.server_args.speculative_algorithm is None
-            ), "Currently HybridAttentionBackend does not support speculative decoding."
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
             )
 
             attn_backend = HybridAttnBackend(
+                self,
                 decode_backend=self._get_attention_backend_from_str(
                     self.decode_attention_backend_str
                 ),
@@ -7,6 +7,8 @@ import requests
 from sglang.srt.utils import get_device_sm, kill_process_tree
 from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k
 from sglang.test.test_utils import (
+    DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+    DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST,
     DEFAULT_MODEL_NAME_FOR_TEST,
     DEFAULT_MODEL_NAME_FOR_TEST_MLA,
     DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
@@ -36,7 +38,7 @@ class TestHybridAttnBackendBase(CustomTestCase):
     base_url = DEFAULT_URL_FOR_TEST
     accuracy_threshold = 0.65  # derived tests need to override this
     speculative_decode = False
-    spec_decode_threshold = 1.0  # derived spec decoding tests need to override this
+    spec_decode_threshold = 2.2  # derived spec decoding tests need to override this
 
     @classmethod
     def get_server_args(cls):
@@ -49,8 +51,12 @@ class TestHybridAttnBackendBase(CustomTestCase):
         # please don't do this if you want to make your inference workload faster
         os.environ["SGL_JIT_DEEPGEMM_PRECOMPILE"] = "false"
         os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "false"
+        if cls.speculative_decode:
+            model = DEFAULT_EAGLE_TARGET_MODEL_FOR_TEST
+        else:
+            model = cls.model
         cls.process = popen_launch_server(
-            cls.model,
+            model,
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=cls.get_server_args(),
@@ -105,5 +111,26 @@ class TestHybridAttnBackendTorchCompile(TestHybridAttnBackendBase):
         return DEFAULT_SERVER_ARGS + ["--enable-torch-compile"]
 
 
+class TestHybridAttnBackendSpeculativeDecoding(TestHybridAttnBackendBase):
+    speculative_decode = True
+    # This eagle test uses a very small model, so the accuracy is low.
+    accuracy_threshold = 0.2
+
+    @classmethod
+    def get_server_args(cls):
+        return DEFAULT_SERVER_ARGS + [
+            "--speculative-algorithm",
+            "EAGLE",
+            "--speculative-draft",
+            DEFAULT_EAGLE_DRAFT_MODEL_FOR_TEST,
+            "--speculative-num-steps",
+            "3",
+            "--speculative-eagle-topk",
+            "2",
+            "--speculative-num-draft-tokens",
+            "4",
+        ]
+
+
 if __name__ == "__main__":
     unittest.main()
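
To run the new speculative-decoding case in isolation, something like the following sketch should work; the module name test_hybrid_attn_backend is an assumption based on the class names above, so adjust it to the actual file path in the repository:

import unittest

# Module name assumed from the test class names above; adjust as needed.
suite = unittest.defaultTestLoader.loadTestsFromName(
    "test_hybrid_attn_backend.TestHybridAttnBackendSpeculativeDecoding"
)
unittest.TextTestRunner(verbosity=2).run(suite)
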