diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 24e3eca95..56c120a0f 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -187,138 +187,248 @@ class EAGLEWorker(TpModelWorker):
         self.has_prefill_wrapper_verify = False
         self.draft_extend_attn_backend = None
 
-        if self.server_args.attention_backend == "flashinfer":
-            if not global_server_args_dict["use_mla_backend"]:
-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                    FlashInferMultiStepDraftBackend,
-                )
+        # Initialize decode attention backend
+        self.draft_attn_backend = self._create_decode_backend()
 
-                self.draft_attn_backend = FlashInferMultiStepDraftBackend(
-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                    FlashInferMLAMultiStepDraftBackend,
-                )
-
-                self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
-                    self.draft_model_runner,
-                    self.topk,
-                    self.speculative_num_steps,
-                )
-                self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
-                    self.draft_model_runner,
-                    skip_prefill=False,
-                )
-            self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "triton":
-            from sglang.srt.layers.attention.triton_backend import (
-                TritonAttnBackend,
-                TritonMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = TritonMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TritonAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "aiter":
-            from sglang.srt.layers.attention.aiter_backend import (
-                AiterAttnBackend,
-                AiterMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = AiterMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = AiterAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-            self.has_prefill_wrapper_verify = False
-        elif self.server_args.attention_backend == "fa3":
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-                FlashAttentionMultiStepBackend,
-            )
-
-            self.draft_attn_backend = FlashAttentionMultiStepBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = FlashAttentionBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-        elif self.server_args.attention_backend == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import (
-                FlashMLAMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-        elif self.server_args.attention_backend == "trtllm_mha":
-            from sglang.srt.layers.attention.trtllm_mha_backend import (
-                TRTLLMHAAttnBackend,
-                TRTLLMHAAttnMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-            self.has_prefill_wrapper_verify = True
-        elif self.server_args.attention_backend == "trtllm_mla":
-            if not global_server_args_dict["use_mla_backend"]:
-                raise ValueError(
-                    "trtllm_mla backend requires MLA model (use_mla_backend=True)."
-                )
-
-            from sglang.srt.layers.attention.trtllm_mla_backend import (
-                TRTLLMMLABackend,
-                TRTLLMMLAMultiStepDraftBackend,
-            )
-
-            self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
-                self.draft_model_runner,
-                self.topk,
-                self.speculative_num_steps,
-            )
-            self.draft_extend_attn_backend = TRTLLMMLABackend(
-                self.draft_model_runner,
-                skip_prefill=False,
-            )
-            self.has_prefill_wrapper_verify = True
-        else:
-            raise ValueError(
-                f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
-            )
+        # Initialize draft-extend (prefill) attention backend
+        self.draft_extend_attn_backend = self._create_draft_extend_backend()
 
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
 
+    def _create_backend(
+        self, backend_name: str, backend_map: dict, error_template: str
+    ):
+        """Build the attention backend selected by a server argument.
+
+        Args:
+            backend_name: Name of the ``server_args`` attribute that holds
+                the phase-specific backend choice.
+            backend_map: Maps a backend name to a zero-argument factory.
+            error_template: Message with a ``{backend_type}`` placeholder,
+                used when the resolved backend has no factory.
+
+        Returns:
+            Whatever the selected factory returns.
+
+        Raises:
+            ValueError: If the resolved backend is not in ``backend_map``.
+        """
+        backend_type = getattr(self.server_args, backend_name)
+        if backend_type is None:
+            # No phase-specific choice: fall back to the global setting.
+            backend_type = self.server_args.attention_backend
+
+        if backend_type not in backend_map:
+            raise ValueError(error_template.format(backend_type=backend_type))
+
+        return backend_map[backend_type]()
+
+    def _create_decode_backend(self):
+        """Create the multi-step draft backend used as ``draft_attn_backend``.
+
+        Raises:
+            ValueError: If EAGLE has no draft backend for the selected name.
+        """
+        backend_map = {
+            "flashinfer": self._create_flashinfer_decode_backend,
+            "triton": self._create_triton_decode_backend,
+            "aiter": self._create_aiter_decode_backend,
+            "fa3": self._create_fa3_decode_backend,
+            "flashmla": self._create_flashmla_decode_backend,
+            "trtllm_mha": self._create_trtllm_mha_decode_backend,
+            "trtllm_mla": self._create_trtllm_mla_decode_backend,
+        }
+
+        return self._create_backend(
+            "decode_attention_backend",
+            backend_map,
+            "EAGLE is not supported in decode attention backend {backend_type}",
+        )
+
+    def _create_draft_extend_backend(self):
+        """Create the backend stored as ``draft_extend_attn_backend``.
+
+        Returns:
+            The backend instance, or ``None`` for ``flashmla``.
+
+        Raises:
+            ValueError: If EAGLE has no such backend for the selected name.
+        """
+        backend_map = {
+            "flashinfer": self._create_flashinfer_prefill_backend,
+            "triton": self._create_triton_prefill_backend,
+            "aiter": self._create_aiter_prefill_backend,
+            "fa3": self._create_fa3_prefill_backend,
+            "flashmla": self._create_flashmla_prefill_backend,
+            "trtllm_mha": self._create_trtllm_mha_prefill_backend,
+            "trtllm_mla": self._create_trtllm_mla_prefill_backend,
+        }
+
+        return self._create_backend(
+            "prefill_attention_backend",
+            backend_map,
+            "EAGLE is not supported in prefill attention backend {backend_type}",
+        )
+
+    def _create_flashinfer_decode_backend(self):
+        """Create the FlashInfer multi-step draft backend.
+
+        Uses the MLA implementation when ``use_mla_backend`` is set; both
+        variants set ``has_prefill_wrapper_verify`` to True.
+        """
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferMultiStepDraftBackend,
+            )
+
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
+            )
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAMultiStepDraftBackend,
+            )
+
+            self.has_prefill_wrapper_verify = True
+            return FlashInferMLAMultiStepDraftBackend(
+                self.draft_model_runner, self.topk, self.speculative_num_steps
+            )
+
+    def _create_triton_decode_backend(self):
+        """Create the Triton multi-step draft backend."""
+        from sglang.srt.layers.attention.triton_backend import (
+            TritonMultiStepDraftBackend,
+        )
+
+        return TritonMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_aiter_decode_backend(self):
+        """Create the Aiter multi-step draft backend."""
+        from sglang.srt.layers.attention.aiter_backend import AiterMultiStepDraftBackend
+
+        return AiterMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_fa3_decode_backend(self):
+        """Create the FlashAttention (fa3) multi-step draft backend."""
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionMultiStepBackend,
+        )
+
+        return FlashAttentionMultiStepBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashmla_decode_backend(self):
+        """Create the FlashMLA multi-step draft backend."""
+        from sglang.srt.layers.attention.flashmla_backend import (
+            FlashMLAMultiStepDraftBackend,
+        )
+
+        return FlashMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mha_decode_backend(self):
+        """Create the trtllm_mha draft backend; sets has_prefill_wrapper_verify."""
+        from sglang.srt.layers.attention.trtllm_mha_backend import (
+            TRTLLMHAAttnMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMHAAttnMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_trtllm_mla_decode_backend(self):
+        """Create the trtllm_mla draft backend; sets has_prefill_wrapper_verify.
+
+        Raises:
+            ValueError: If ``use_mla_backend`` is not enabled.
+        """
+        if not global_server_args_dict["use_mla_backend"]:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+            )
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import (
+            TRTLLMMLAMultiStepDraftBackend,
+        )
+
+        self.has_prefill_wrapper_verify = True
+        return TRTLLMMLAMultiStepDraftBackend(
+            self.draft_model_runner, self.topk, self.speculative_num_steps
+        )
+
+    def _create_flashinfer_prefill_backend(self):
+        """Create the FlashInfer draft-extend backend (MLA variant if enabled)."""
+        if not global_server_args_dict["use_mla_backend"]:
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
+            return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
+        else:
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
+            return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_triton_prefill_backend(self):
+        """Create the Triton draft-extend backend."""
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_aiter_prefill_backend(self):
+        """Create the Aiter draft-extend backend."""
+        from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+        return AiterAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_fa3_prefill_backend(self):
+        """Create the FlashAttention (fa3) draft-extend backend."""
+        from sglang.srt.layers.attention.flashattention_backend import (
+            FlashAttentionBackend,
+        )
+
+        return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_flashmla_prefill_backend(self):
+        """Return ``None``: flashmla has no draft-extend backend.
+
+        The pre-refactor flashmla branch built only the draft decode backend
+        and left ``draft_extend_attn_backend`` as ``None``; keep that rather
+        than rejecting flashmla with a ``ValueError``.
+        """
+        return None
+
+    def _create_trtllm_mha_prefill_backend(self):
+        """Create the trtllm_mha draft-extend backend."""
+        from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+        return TRTLLMHAAttnBackend(self.draft_model_runner, skip_prefill=False)
+
+    def _create_trtllm_mla_prefill_backend(self):
+        """Create the trtllm_mla draft-extend backend.
+
+        Raises:
+            ValueError: If ``use_mla_backend`` is not enabled.
+        """
+        if not global_server_args_dict["use_mla_backend"]:
+            raise ValueError(
+                "trtllm_mla backend requires MLA model (use_mla_backend=True)."
+            )
+
+        from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+        return TRTLLMMLABackend(self.draft_model_runner, skip_prefill=False)
+
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         self.cuda_graph_runner = None