diff --git a/python/sglang/srt/layers/attention/aiter_backend.py b/python/sglang/srt/layers/attention/aiter_backend.py
index 30901805d..dafe5ee19 100644
--- a/python/sglang/srt/layers/attention/aiter_backend.py
+++ b/python/sglang/srt/layers/attention/aiter_backend.py
@@ -1064,7 +1064,7 @@ class AiterMultiStepDraftBackend:
             device=model_runner.device,
         )
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 AiterAttnBackend(
                     model_runner,
@@ -1107,7 +1107,7 @@ class AiterMultiStepDraftBackend:
             self.page_size,
         )
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
@@ -1141,7 +1141,7 @@ class AiterMultiStepDraftBackend:
             dtype=torch.int32,
             device=self.device,
         )
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index 4fae2cb1d..8d8f789d0 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -2320,7 +2320,7 @@ class FlashAttentionMultiStepBackend:
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashAttentionBackend(
                     model_runner,
@@ -2335,7 +2335,7 @@ class FlashAttentionMultiStepBackend:
             self.attn_backends[i].init_forward_metadata(forward_batch)
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
 
     def init_forward_metadata_capture_cuda_graph(
diff --git a/python/sglang/srt/layers/attention/flashinfer_backend.py b/python/sglang/srt/layers/attention/flashinfer_backend.py
index 7cae8e59d..aeb06bfa9 100644
--- a/python/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_backend.py
@@ -1405,7 +1405,7 @@ class FlashInferMultiStepDraftBackend:
             (max_bs,), dtype=torch.int32, device=model_runner.device
         )
         self.attn_backends: List[FlashInferAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferAttnBackend(
                     model_runner,
@@ -1493,7 +1493,7 @@ class FlashInferMultiStepDraftBackend:
             device="cuda",
         )
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
index 82d1b05b4..ad9cbfd44 100644
--- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
+++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py
@@ -916,7 +916,7 @@ class FlashInferMLAMultiStepDraftBackend:
         )
 
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashInferMLAAttnBackend(
                     model_runner,
@@ -998,7 +998,7 @@ class FlashInferMLAMultiStepDraftBackend:
             device="cuda",
         )
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
             )
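The change repeated across these attention backends shrinks the per-step backend list from speculative_num_steps to speculative_num_steps - 1. The patch itself does not state the rationale, but the pattern is consistent with the multi-step draft loop running only speculative_num_steps - 1 decode forwards (the remaining draft token comes out of the preceding extend/verify pass), so the last slot was never exercised. A minimal sketch of the allocation pattern under that assumption; the class and the make_backend callable are hypothetical, not sglang APIs:

    # Hypothetical sketch: allocate one per-step attention backend for
    # draft decode steps 0 .. n-2 only.
    class MultiStepDraftBackendSketch:
        def __init__(self, make_backend, speculative_num_steps: int):
            self.speculative_num_steps = speculative_num_steps
            self.attn_backends = [
                make_backend(step) for step in range(speculative_num_steps - 1)
            ]

        def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
            # Every consumer loop must use the same n - 1 bound as allocation.
            for backend in self.attn_backends:
                backend.init_cuda_graph_state(max_bs, max_num_tokens)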
diff --git a/python/sglang/srt/layers/attention/flashmla_backend.py b/python/sglang/srt/layers/attention/flashmla_backend.py
index d85222806..81bbde7a5 100644
--- a/python/sglang/srt/layers/attention/flashmla_backend.py
+++ b/python/sglang/srt/layers/attention/flashmla_backend.py
@@ -478,7 +478,7 @@ class FlashMLAMultiStepDraftBackend:
         )
 
         self.attn_backends = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 FlashMLABackend(
                     model_runner,
@@ -506,7 +506,7 @@ class FlashMLAMultiStepDraftBackend:
         self.common_template(forward_batch, call_fn)
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs, max_num_tokens, block_kv_indices=None
             )
diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 4fab75700..97b869473 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -918,7 +918,7 @@ class TritonMultiStepDraftBackend:
             device=model_runner.device,
         )
         self.attn_backends: List[TritonAttnBackend] = []
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends.append(
                 TritonAttnBackend(
                     model_runner,
@@ -969,7 +969,7 @@ class TritonMultiStepDraftBackend:
         if call_fn is None:
             return
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
                 : seq_lens_sum * self.topk + bs * (i + 1)
@@ -1009,7 +1009,8 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int32,
             device=self.device,
         )
-        for i in range(self.speculative_num_steps):
+
+        for i in range(self.speculative_num_steps - 1):
             self.attn_backends[i].init_cuda_graph_state(
                 max_bs,
                 max_num_tokens,
diff --git a/python/sglang/srt/layers/attention/trtllm_mha_backend.py b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
index 454a388f9..427dc5c67 100644
--- a/python/sglang/srt/layers/attention/trtllm_mha_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mha_backend.py
@@ -637,7 +637,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
         self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
     ):
         super().__init__(model_runner, topk, speculative_num_steps)
-        for i in range(speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i] = TRTLLMHAAttnBackend(
                model_runner,
                skip_prefill=True,
@@ -651,7 +651,7 @@ class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
            self.attn_backends[i].init_forward_metadata(forward_batch)
 
     def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
 
    def init_forward_metadata_capture_cuda_graph(
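The trtllm_mha hunks above (and the trtllm_mla hunk below) touch subclass constructors rather than loop bodies: these backends override entries in the attn_backends list that the FlashInfer base class allocates. Once the base class allocates only speculative_num_steps - 1 entries, an override loop still bounded by range(speculative_num_steps) would index one past the end of the list. A toy reproduction with illustrative class names, not the sglang ones:

    class Base:
        def __init__(self, num_steps: int):
            self.speculative_num_steps = num_steps
            self.attn_backends = [None] * (num_steps - 1)

    class Sub(Base):
        def __init__(self, num_steps: int):
            super().__init__(num_steps)
            # With the old range(num_steps) bound this loop would raise
            # IndexError on its last iteration, since the base class now
            # allocates only num_steps - 1 slots.
            for i in range(self.speculative_num_steps - 1):
                self.attn_backends[i] = object()

    Sub(3)  # ok; with range(3) in the loop above this would fail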
diff --git a/python/sglang/srt/layers/attention/trtllm_mla_backend.py b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
index 3727524ef..4943eed90 100755
--- a/python/sglang/srt/layers/attention/trtllm_mla_backend.py
+++ b/python/sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -735,7 +735,7 @@ class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     ):
         super().__init__(model_runner, topk, speculative_num_steps)
 
-        for i in range(self.speculative_num_steps):
+        for i in range(self.speculative_num_steps - 1):
            self.attn_backends[i] = TRTLLMMLABackend(
                model_runner,
                skip_prefill=True,
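The new module below lifts the backend-construction logic out of EAGLEWorker into a standalone DraftBackendFactory; the worker-side wiring appears in the eagle_worker.py hunks at the end of this patch. One subtlety worth noting: several creator methods set self.has_prefill_wrapper_verify = True, and after this move that attribute lands on the factory instance rather than on the worker, which keeps its own flag initialized to False; whether anything still reads the factory's copy is outside this diff. A sketch of the call sequence, with server_args and draft_model_runner standing in for the objects the worker already holds and example values for the remaining arguments:

    from sglang.srt.speculative.draft_utils import DraftBackendFactory

    factory = DraftBackendFactory(
        server_args,              # the worker's ServerArgs
        draft_model_runner,       # the draft model's ModelRunner
        topk=4,                   # example value
        speculative_num_steps=3,  # example value
    )
    draft_attn_backend = factory.create_decode_backend()
    draft_extend_attn_backend = factory.create_draft_extend_backend()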
diff --git a/python/sglang/srt/speculative/draft_utils.py b/python/sglang/srt/speculative/draft_utils.py
new file mode 100644
index 000000000..fd856b61e
--- /dev/null
+++ b/python/sglang/srt/speculative/draft_utils.py
@@ -0,0 +1,222 @@
+import logging
+
+from sglang.srt.server_args import ServerArgs, get_global_server_args
+from sglang.srt.utils.common import is_blackwell
+
+logger = logging.getLogger(__name__)
+
+
+class DraftBackendFactory:
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        draft_model_runner,
+        topk: int,
+        speculative_num_steps: int,
+    ):
+        self.server_args = server_args
+        self.draft_model_runner = draft_model_runner
+        self.topk = topk
+        self.speculative_num_steps = speculative_num_steps
+
+    def _create_backend(
+        self, backend_name: str, backend_map: dict, error_template: str
+    ):
+        backend_type = getattr(self.server_args, backend_name)
+        if backend_type is None:
+            backend_type = self.server_args.attention_backend
+
+        if backend_type not in backend_map:
+            raise ValueError(error_template.format(backend_type=backend_type))
+
+        return backend_map[backend_type]()
+
+    def create_decode_backend(self):
+        if self.speculative_num_steps == 1:
+
+            class DummyAttnBackend:
+                def __init__(self):
+                    pass
+
+                def init_forward_metadata(*args, **kwargs):
+                    pass
+
+            return DummyAttnBackend()
+
+        backend_map = {
+            "flashinfer": self._create_flashinfer_decode_backend,
+            "triton": self._create_triton_decode_backend,
+            "aiter": self._create_aiter_decode_backend,
+            "fa3": self._create_fa3_decode_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_decode_backend
+                if not is_blackwell()
+                else self._create_triton_decode_backend
+            ),
+            "flashmla": self._create_flashmla_decode_backend,
+            "trtllm_mha": self._create_trtllm_mha_decode_backend,
+            "trtllm_mla": self._create_trtllm_mla_decode_backend,
+        }
+
+        return self._create_backend(
+            "decode_attention_backend",
+            backend_map,
+            "EAGLE is not supported in decode attention backend {backend_type}",
+        )
+
+    def create_draft_extend_backend(self):
+        backend_map = {
+            "flashinfer": self._create_flashinfer_prefill_backend,
+            "triton": self._create_triton_prefill_backend,
+            "aiter": self._create_aiter_prefill_backend,
+            "fa3": self._create_fa3_prefill_backend,
+            "hybrid_linear_attn": (
+                self._create_fa3_prefill_backend
+                if not is_blackwell()
+                else self._create_triton_prefill_backend
+            ),
+            "flashmla": self._create_flashmla_prefill_backend,
+            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
+            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+        }
+        backend_name = (
+            "decode_attention_backend"
+            if self.server_args.speculative_attention_mode == "decode"
+            else "prefill_attention_backend"
+        )
+        return self._create_backend(
+            backend_name,
+            backend_map,
+            "EAGLE is not supported in attention backend {backend_type}",
+        )
+
+    def _create_flashinfer_decode_backend(self):
+        if not get_global_server_args().use_mla_backend:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
+            )
+
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
+            )
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
+            )
+
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMLAMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
+            )
+
+    def _create_triton_decode_backend(self):
+        from sglang.srt.layers.attention.triton_backend import (
+            TritonMultiStepDraftBackend,
+        )
+
+        return TritonMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_aiter_decode_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
+
+        return AiterMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_fa3_decode_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionMultiStepBackend,
+        )
+
+        return FlashAttentionMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashmla_decode_backend(self):
+        from sglang.srt.layers.attention.flashmla_backend import (
+            FlashMLAMultiStepDraftBackend,
+        )
+
+        return FlashMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mha_decode_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import (
+            TRTLLMHAAttnMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMHAAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mla_decode_backend(self):
+        if not get_global_server_args().use_mla_backend:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+            )
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import (
+            TRTLLMMLAMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashinfer_prefill_backend(self):
+        if not get_global_server_args().use_mla_backend:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
+            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
+            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_triton_prefill_backend(self):
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_aiter_prefill_backend(self):
+        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_fa3_prefill_backend(self):
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mha_prefill_backend(self):
+        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mla_prefill_backend(self):
+        if not get_global_server_args().use_mla_backend:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+            )
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_flashmla_prefill_backend(self):
+        logger.warning(
+            "flashmla prefill backend is not yet supported for draft extend."
+        )
+        return None
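DraftBackendFactory._create_backend resolves the backend in two steps: read server_args.<backend_name>, and if it is unset fall back to server_args.attention_backend; anything outside the map raises with the supplied error template. A standalone sketch of just the resolution rule, with SimpleNamespace standing in for ServerArgs:

    from types import SimpleNamespace

    def resolve(server_args, backend_name: str) -> str:
        # Mirrors the lookup-with-fallback in DraftBackendFactory._create_backend.
        backend_type = getattr(server_args, backend_name)
        if backend_type is None:
            backend_type = server_args.attention_backend
        return backend_type

    args = SimpleNamespace(attention_backend="triton", decode_attention_backend=None)
    assert resolve(args, "decode_attention_backend") == "triton"  # fallback
    args.decode_attention_backend = "fa3"
    assert resolve(args, "decode_attention_backend") == "fa3"     # explicit wins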
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index a8461c999..08282e533 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -27,7 +27,8 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.server_args import ServerArgs, get_global_server_args
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.draft_utils import DraftBackendFactory
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
 )
@@ -195,205 +196,23 @@ class EAGLEWorker(TpModelWorker):
         self.has_prefill_wrapper_verify = False
         self.draft_extend_attn_backend = None
 
+        draft_backend_factory = DraftBackendFactory(
+            self.server_args,
+            self.draft_model_runner,
+            self.topk,
+            self.speculative_num_steps,
+        )
+
         # Initialize decode attention backend
-        self.draft_attn_backend = self._create_decode_backend()
+        self.draft_attn_backend = draft_backend_factory.create_decode_backend()
 
         # Initialize draft extend attention backend (respects speculative_attention_mode setting)
-        self.draft_extend_attn_backend = self._create_draft_extend_backend()
+        self.draft_extend_attn_backend = (
+            draft_backend_factory.create_draft_extend_backend()
+        )
 
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
 
-    def _create_backend(
-        self, backend_name: str, backend_map: dict, error_template: str
-    ):
-        backend_type = getattr(self.server_args, backend_name)
-        if backend_type is None:
-            backend_type = self.server_args.attention_backend
-
-        if backend_type not in backend_map:
-            raise ValueError(error_template.format(backend_type=backend_type))
-
-        return backend_map[backend_type]()
-
-    def _create_decode_backend(self):
-        backend_map = {
-            "flashinfer": self._create_flashinfer_decode_backend,
-            "triton": self._create_triton_decode_backend,
-            "aiter": self._create_aiter_decode_backend,
-            "fa3": self._create_fa3_decode_backend,
-            "hybrid_linear_attn": (
-                self._create_fa3_decode_backend
-                if not is_blackwell()
-                else self._create_triton_decode_backend
-            ),
-            "flashmla": self._create_flashmla_decode_backend,
-            "trtllm_mha": self._create_trtllm_mha_decode_backend,
-            "trtllm_mla": self._create_trtllm_mla_decode_backend,
-        }
-
-        return self._create_backend(
-            "decode_attention_backend",
-            backend_map,
-            "EAGLE is not supported in decode attention backend {backend_type}",
-        )
-
-    def _create_draft_extend_backend(self):
-        backend_map = {
-            "flashinfer": self._create_flashinfer_prefill_backend,
-            "triton": self._create_triton_prefill_backend,
-            "aiter": self._create_aiter_prefill_backend,
-            "fa3": self._create_fa3_prefill_backend,
-            "hybrid_linear_attn": (
-                self._create_fa3_prefill_backend
-                if not is_blackwell()
-                else self._create_triton_prefill_backend
-            ),
-            "flashmla": self._create_flashmla_prefill_backend,
-            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
-            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
-        }
-        backend_name = (
-            "decode_attention_backend"
-            if self.server_args.speculative_attention_mode == "decode"
-            else "prefill_attention_backend"
-        )
-        return self._create_backend(
-            backend_name,
-            backend_map,
-            "EAGLE is not supported in attention backend {backend_type}",
-        )
-
-    def _create_flashinfer_decode_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferMultiStepDraftBackend,
-            )
-
-            self.has_prefill_wrapper_verify = True
-            return FlashInferMultiStepDraftBackend(
-                self.draft_model_runner, self.topk, self.speculative_num_steps
-            )
-        else:
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAMultiStepDraftBackend,
-            )
-
-            self.has_prefill_wrapper_verify = True
-            return FlashInferMLAMultiStepDraftBackend(
-                self.draft_model_runner, self.topk, self.speculative_num_steps
-            )
-
-    def _create_triton_decode_backend(self):
-        from sglang.srt.layers.attention.triton_backend import (
-            TritonMultiStepDraftBackend,
-        )
-
-        return TritonMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_aiter_decode_backend(self):
-        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
-
-        return AiterMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_fa3_decode_backend(self):
-        from sglang.srt.layers.attention.flashattention_backend import (
-            FlashAttentionMultiStepBackend,
-        )
-
-        return FlashAttentionMultiStepBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_flashmla_decode_backend(self):
-        from sglang.srt.layers.attention.flashmla_backend import (
-            FlashMLAMultiStepDraftBackend,
-        )
-
-        return FlashMLAMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_trtllm_mha_decode_backend(self):
-        from sglang.srt.layers.attention.trtllm_mha_backend import (
-            TRTLLMHAAttnMultiStepDraftBackend,
-        )
-
-        self.has_prefill_wrapper_verify = True
-        return TRTLLMHAAttnMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_trtllm_mla_decode_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            raise ValueError(
-                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-            )
-
-        from sglang.srt.layers.attention.trtllm_mla_backend import (
-            TRTLLMMLAMultiStepDraftBackend,
-        )
-
-        self.has_prefill_wrapper_verify = True
-        return TRTLLMMLAMultiStepDraftBackend(
-            self.draft_model_runner, self.topk, self.speculative_num_steps
-        )
-
-    def _create_flashinfer_prefill_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            from sglang.srt.layers.attention.flashinfer_backend import (
-                FlashInferAttnBackend,
-            )
-
-            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
-        else:
-            from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                FlashInferMLAAttnBackend,
-            )
-
-            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_triton_prefill_backend(self):
-        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_aiter_prefill_backend(self):
-        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
-
-        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_fa3_prefill_backend(self):
-        from sglang.srt.layers.attention.flashattention_backend import (
-            FlashAttentionBackend,
-        )
-
-        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_trtllm_mha_prefill_backend(self):
-        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
-
-        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_trtllm_mla_prefill_backend(self):
-        if not get_global_server_args().use_mla_backend:
-            raise ValueError(
-                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-            )
-
-        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
-
-        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
-
-    def _create_flashmla_prefill_backend(self):
-        logger.warning(
-            "flashmla prefill backend is not yet supported for draft extend."
-        )
-        return None
-
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None
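The removed EAGLEWorker methods reappear verbatim in draft_utils.py, so the move appears behavior-preserving apart from one addition: create_decode_backend now short-circuits when speculative_num_steps == 1, returning a stub whose init_forward_metadata is a no-op, since a single speculative step leaves no multi-step draft decode passes to plan for. A sketch of the contract that stub has to satisfy, mirroring the DummyAttnBackend defined inline in draft_utils.py:

    # The draft code path only ever calls init_forward_metadata on the
    # decode backend before running steps; with one speculative step there
    # are zero such steps, so a safe no-op suffices.
    class DummyAttnBackend:
        def init_forward_metadata(self, *args, **kwargs):
            pass

    backend = DummyAttnBackend()
    backend.init_forward_metadata()  # no-op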