[feat] Add P/D attention select for draft model (#9755)
Co-authored-by: 纬杭 <ximing.wxm@antgroup.com>
This commit is contained in:
@@ -187,138 +187,184 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.has_prefill_wrapper_verify = False
|
self.has_prefill_wrapper_verify = False
|
||||||
self.draft_extend_attn_backend = None
|
self.draft_extend_attn_backend = None
|
||||||
|
|
||||||
if self.server_args.attention_backend == "flashinfer":
|
# Initialize decode attention backend
|
||||||
if not global_server_args_dict["use_mla_backend"]:
|
self.draft_attn_backend = self._create_decode_backend()
|
||||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
|
||||||
FlashInferAttnBackend,
|
|
||||||
FlashInferMultiStepDraftBackend,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.draft_attn_backend = FlashInferMultiStepDraftBackend(
|
# Initialize prefill attention backend
|
||||||
self.draft_model_runner,
|
self.draft_extend_attn_backend = self._create_draft_extend_backend()
|
||||||
self.topk,
|
|
||||||
self.speculative_num_steps,
|
|
||||||
)
|
|
||||||
self.draft_extend_attn_backend = FlashInferAttnBackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
skip_prefill=False,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
|
||||||
FlashInferMLAAttnBackend,
|
|
||||||
FlashInferMLAMultiStepDraftBackend,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
self.topk,
|
|
||||||
self.speculative_num_steps,
|
|
||||||
)
|
|
||||||
self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
skip_prefill=False,
|
|
||||||
)
|
|
||||||
self.has_prefill_wrapper_verify = True
|
|
||||||
elif self.server_args.attention_backend == "triton":
|
|
||||||
from sglang.srt.layers.attention.triton_backend import (
|
|
||||||
TritonAttnBackend,
|
|
||||||
TritonMultiStepDraftBackend,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.draft_attn_backend = TritonMultiStepDraftBackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
self.topk,
|
|
||||||
self.speculative_num_steps,
|
|
||||||
)
|
|
||||||
self.draft_extend_attn_backend = TritonAttnBackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
skip_prefill=False,
|
|
||||||
)
|
|
||||||
elif self.server_args.attention_backend == "aiter":
|
|
||||||
from sglang.srt.layers.attention.aiter_backend import (
|
|
||||||
AiterAttnBackend,
|
|
||||||
AiterMultiStepDraftBackend,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.draft_attn_backend = AiterMultiStepDraftBackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
self.topk,
|
|
||||||
self.speculative_num_steps,
|
|
||||||
)
|
|
||||||
self.draft_extend_attn_backend = AiterAttnBackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
skip_prefill=False,
|
|
||||||
)
|
|
||||||
self.has_prefill_wrapper_verify = False
|
|
||||||
elif self.server_args.attention_backend == "fa3":
|
|
||||||
from sglang.srt.layers.attention.flashattention_backend import (
|
|
||||||
FlashAttentionBackend,
|
|
||||||
FlashAttentionMultiStepBackend,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.draft_attn_backend = FlashAttentionMultiStepBackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
self.topk,
|
|
||||||
self.speculative_num_steps,
|
|
||||||
)
|
|
||||||
self.draft_extend_attn_backend = FlashAttentionBackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
skip_prefill=False,
|
|
||||||
)
|
|
||||||
elif self.server_args.attention_backend == "flashmla":
|
|
||||||
from sglang.srt.layers.attention.flashmla_backend import (
|
|
||||||
FlashMLAMultiStepDraftBackend,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.draft_attn_backend = FlashMLAMultiStepDraftBackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
self.topk,
|
|
||||||
self.speculative_num_steps,
|
|
||||||
)
|
|
||||||
elif self.server_args.attention_backend == "trtllm_mha":
|
|
||||||
from sglang.srt.layers.attention.trtllm_mha_backend import (
|
|
||||||
TRTLLMHAAttnBackend,
|
|
||||||
TRTLLMHAAttnMultiStepDraftBackend,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.draft_attn_backend = TRTLLMHAAttnMultiStepDraftBackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
self.topk,
|
|
||||||
self.speculative_num_steps,
|
|
||||||
)
|
|
||||||
self.draft_extend_attn_backend = TRTLLMHAAttnBackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
skip_prefill=False,
|
|
||||||
)
|
|
||||||
self.has_prefill_wrapper_verify = True
|
|
||||||
elif self.server_args.attention_backend == "trtllm_mla":
|
|
||||||
if not global_server_args_dict["use_mla_backend"]:
|
|
||||||
raise ValueError(
|
|
||||||
"trtllm_mla backend requires MLA model (use_mla_backend=True)."
|
|
||||||
)
|
|
||||||
|
|
||||||
from sglang.srt.layers.attention.trtllm_mla_backend import (
|
|
||||||
TRTLLMMLABackend,
|
|
||||||
TRTLLMMLAMultiStepDraftBackend,
|
|
||||||
)
|
|
||||||
|
|
||||||
self.draft_attn_backend = TRTLLMMLAMultiStepDraftBackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
self.topk,
|
|
||||||
self.speculative_num_steps,
|
|
||||||
)
|
|
||||||
self.draft_extend_attn_backend = TRTLLMMLABackend(
|
|
||||||
self.draft_model_runner,
|
|
||||||
skip_prefill=False,
|
|
||||||
)
|
|
||||||
self.has_prefill_wrapper_verify = True
|
|
||||||
else:
|
|
||||||
raise ValueError(
|
|
||||||
f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
|
|
||||||
)
|
|
||||||
|
|
||||||
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
|
self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
|
||||||
|
|
||||||
|
def _create_backend(
    self, backend_name: str, backend_map: dict, error_template: str
):
    """Resolve which attention backend to use and invoke its factory.

    Args:
        backend_name: Name of the attribute on ``self.server_args`` that
            holds the specialised backend choice.
        backend_map: Mapping from backend type to a zero-argument callable
            that builds the backend.
        error_template: Format string containing a ``{backend_type}``
            placeholder, used for the error message.

    Returns:
        Whatever the selected callable in ``backend_map`` returns.

    Raises:
        ValueError: If the resolved backend type is not a key of
            ``backend_map``.
    """
    configured = getattr(self.server_args, backend_name)
    # A value of None means no specialised choice was made: fall back to
    # the general attention backend setting.
    selected = (
        self.server_args.attention_backend if configured is None else configured
    )

    if selected in backend_map:
        return backend_map[selected]()

    raise ValueError(error_template.format(backend_type=selected))
|
||||||
|
|
||||||
|
def _create_decode_backend(self):
    """Create the attention backend used for draft decoding.

    Collects the ``_create_<name>_decode_backend`` factory of every supported
    backend and hands the table to ``_create_backend`` together with the
    ``decode_attention_backend`` server-argument name.

    Returns:
        The value returned by ``self._create_backend``.
    """
    supported = (
        "flashinfer",
        "triton",
        "aiter",
        "fa3",
        "flashmla",
        "trtllm_mha",
        "trtllm_mla",
    )
    # Each key maps to the bound factory method of the same name.
    factories = {
        name: getattr(self, f"_create_{name}_decode_backend") for name in supported
    }

    unsupported_msg = "EAGLE is not supported in decode attention backend {backend_type}"
    return self._create_backend("decode_attention_backend", factories, unsupported_msg)
|
||||||
|
|
||||||
|
def _create_draft_extend_backend(self):
    """Create the attention backend used for the draft-extend (prefill) pass.

    Builds the table of prefill backend factories and delegates the choice to
    ``_create_backend`` using the ``prefill_attention_backend`` server
    argument name.

    Returns:
        The value returned by the selected factory. For ``"flashmla"`` this
        is ``None`` because no draft-extend backend is built for it.

    Raises:
        ValueError: Propagated from ``_create_backend`` when the resolved
            backend type has no entry in the table.
    """
    backend_map = {
        "flashinfer": self._create_flashinfer_prefill_backend,
        "triton": self._create_triton_prefill_backend,
        "aiter": self._create_aiter_prefill_backend,
        "fa3": self._create_fa3_prefill_backend,
        # "flashmla" is accepted as a decode backend but has no draft-extend
        # implementation. Without this entry, selecting flashmla would raise
        # ValueError here even though decoding with it is supported; returning
        # None keeps the draft-extend backend unset instead.
        # NOTE(review): assumes callers tolerate a None draft-extend backend
        # -- confirm against the code that consumes it.
        "flashmla": lambda: None,
        "trtllm_mha": self._create_trtllm_mha_prefill_backend,
        "trtllm_mla": self._create_trtllm_mla_prefill_backend,
    }

    return self._create_backend(
        "prefill_attention_backend",
        backend_map,
        "EAGLE is not supported in prefill attention backend {backend_type}",
    )
|
||||||
|
|
||||||
|
def _create_flashinfer_decode_backend(self):
    """Construct the FlashInfer multi-step draft backend.

    The MLA variant is chosen when ``global_server_args_dict["use_mla_backend"]``
    is truthy, the regular variant otherwise. In both cases
    ``self.has_prefill_wrapper_verify`` is set to ``True``.

    Returns:
        The chosen multi-step draft backend, built from
        ``self.draft_model_runner``, ``self.topk`` and
        ``self.speculative_num_steps``.
    """
    if global_server_args_dict["use_mla_backend"]:
        from sglang.srt.layers.attention.flashinfer_mla_backend import (
            FlashInferMLAMultiStepDraftBackend as draft_backend_cls,
        )
    else:
        from sglang.srt.layers.attention.flashinfer_backend import (
            FlashInferMultiStepDraftBackend as draft_backend_cls,
        )

    # The flag is set only after the import succeeded and before construction.
    self.has_prefill_wrapper_verify = True
    ctor_args = (self.draft_model_runner, self.topk, self.speculative_num_steps)
    return draft_backend_cls(*ctor_args)
|
||||||
|
|
||||||
|
def _create_triton_decode_backend(self):
    """Construct the Triton multi-step draft backend.

    Returns:
        A ``TritonMultiStepDraftBackend`` built from
        ``self.draft_model_runner``, ``self.topk`` and
        ``self.speculative_num_steps``.
    """
    # The backend module is imported inside the method, on first use.
    from sglang.srt.layers.attention.triton_backend import (
        TritonMultiStepDraftBackend as draft_backend_cls,
    )

    ctor_args = (self.draft_model_runner, self.topk, self.speculative_num_steps)
    return draft_backend_cls(*ctor_args)
|
||||||
|
|
||||||
|
def _create_aiter_decode_backend(self):
    """Construct the Aiter multi-step draft backend.

    Returns:
        An ``AiterMultiStepDraftBackend`` built from
        ``self.draft_model_runner``, ``self.topk`` and
        ``self.speculative_num_steps``.
    """
    # The backend module is imported inside the method, on first use.
    from sglang.srt.layers.attention.aiter_backend import (
        AiterMultiStepDraftBackend as draft_backend_cls,
    )

    ctor_args = (self.draft_model_runner, self.topk, self.speculative_num_steps)
    return draft_backend_cls(*ctor_args)
|
||||||
|
|
||||||
|
def _create_fa3_decode_backend(self):
    """Construct the FlashAttention ("fa3") multi-step backend.

    Returns:
        A ``FlashAttentionMultiStepBackend`` built from
        ``self.draft_model_runner``, ``self.topk`` and
        ``self.speculative_num_steps``.
    """
    # The backend module is imported inside the method, on first use.
    from sglang.srt.layers.attention.flashattention_backend import (
        FlashAttentionMultiStepBackend as draft_backend_cls,
    )

    ctor_args = (self.draft_model_runner, self.topk, self.speculative_num_steps)
    return draft_backend_cls(*ctor_args)
|
||||||
|
|
||||||
|
def _create_flashmla_decode_backend(self):
    """Construct the FlashMLA multi-step draft backend.

    Returns:
        A ``FlashMLAMultiStepDraftBackend`` built from
        ``self.draft_model_runner``, ``self.topk`` and
        ``self.speculative_num_steps``.
    """
    # The backend module is imported inside the method, on first use.
    from sglang.srt.layers.attention.flashmla_backend import (
        FlashMLAMultiStepDraftBackend as draft_backend_cls,
    )

    ctor_args = (self.draft_model_runner, self.topk, self.speculative_num_steps)
    return draft_backend_cls(*ctor_args)
|
||||||
|
|
||||||
|
def _create_trtllm_mha_decode_backend(self):
    """Construct the TRT-LLM MHA multi-step draft backend.

    Also sets ``self.has_prefill_wrapper_verify`` to ``True``.

    Returns:
        A ``TRTLLMHAAttnMultiStepDraftBackend`` built from
        ``self.draft_model_runner``, ``self.topk`` and
        ``self.speculative_num_steps``.
    """
    from sglang.srt.layers.attention.trtllm_mha_backend import (
        TRTLLMHAAttnMultiStepDraftBackend as draft_backend_cls,
    )

    # The flag is set only after the import succeeded and before construction.
    self.has_prefill_wrapper_verify = True
    ctor_args = (self.draft_model_runner, self.topk, self.speculative_num_steps)
    return draft_backend_cls(*ctor_args)
|
||||||
|
|
||||||
|
def _create_trtllm_mla_decode_backend(self):
    """Construct the TRT-LLM MLA multi-step draft backend.

    Also sets ``self.has_prefill_wrapper_verify`` to ``True``.

    Returns:
        A ``TRTLLMMLAMultiStepDraftBackend`` built from
        ``self.draft_model_runner``, ``self.topk`` and
        ``self.speculative_num_steps``.

    Raises:
        ValueError: If ``global_server_args_dict["use_mla_backend"]`` is
            falsy.
    """
    mla_enabled = global_server_args_dict["use_mla_backend"]
    if not mla_enabled:
        raise ValueError(
            "trtllm_mla backend requires MLA model (use_mla_backend=True)."
        )

    from sglang.srt.layers.attention.trtllm_mla_backend import (
        TRTLLMMLAMultiStepDraftBackend as draft_backend_cls,
    )

    # The flag is set only after the import succeeded and before construction.
    self.has_prefill_wrapper_verify = True
    ctor_args = (self.draft_model_runner, self.topk, self.speculative_num_steps)
    return draft_backend_cls(*ctor_args)
|
||||||
|
|
||||||
|
def _create_flashinfer_prefill_backend(self):
    """Construct the FlashInfer backend used for draft extend.

    The MLA variant is chosen when ``global_server_args_dict["use_mla_backend"]``
    is truthy, the regular variant otherwise.

    Returns:
        The chosen backend, built from ``self.draft_model_runner`` with
        ``skip_prefill=False``.
    """
    if global_server_args_dict["use_mla_backend"]:
        from sglang.srt.layers.attention.flashinfer_mla_backend import (
            FlashInferMLAAttnBackend as extend_backend_cls,
        )
    else:
        from sglang.srt.layers.attention.flashinfer_backend import (
            FlashInferAttnBackend as extend_backend_cls,
        )

    runner = self.draft_model_runner
    return extend_backend_cls(runner, skip_prefill=False)
|
||||||
|
|
||||||
|
def _create_triton_prefill_backend(self):
    """Construct the Triton backend used for draft extend.

    Returns:
        A ``TritonAttnBackend`` built from ``self.draft_model_runner`` with
        ``skip_prefill=False``.
    """
    # The backend module is imported inside the method, on first use.
    from sglang.srt.layers.attention.triton_backend import (
        TritonAttnBackend as extend_backend_cls,
    )

    runner = self.draft_model_runner
    return extend_backend_cls(runner, skip_prefill=False)
|
||||||
|
|
||||||
|
def _create_aiter_prefill_backend(self):
    """Construct the Aiter backend used for draft extend.

    Returns:
        An ``AiterAttnBackend`` built from ``self.draft_model_runner`` with
        ``skip_prefill=False``.
    """
    # The backend module is imported inside the method, on first use.
    from sglang.srt.layers.attention.aiter_backend import (
        AiterAttnBackend as extend_backend_cls,
    )

    runner = self.draft_model_runner
    return extend_backend_cls(runner, skip_prefill=False)
|
||||||
|
|
||||||
|
def _create_fa3_prefill_backend(self):
    """Construct the FlashAttention ("fa3") backend used for draft extend.

    Returns:
        A ``FlashAttentionBackend`` built from ``self.draft_model_runner``
        with ``skip_prefill=False``.
    """
    # The backend module is imported inside the method, on first use.
    from sglang.srt.layers.attention.flashattention_backend import (
        FlashAttentionBackend as extend_backend_cls,
    )

    runner = self.draft_model_runner
    return extend_backend_cls(runner, skip_prefill=False)
|
||||||
|
|
||||||
|
def _create_trtllm_mha_prefill_backend(self):
    """Construct the TRT-LLM MHA backend used for draft extend.

    Returns:
        A ``TRTLLMHAAttnBackend`` built from ``self.draft_model_runner`` with
        ``skip_prefill=False``.
    """
    # The backend module is imported inside the method, on first use.
    from sglang.srt.layers.attention.trtllm_mha_backend import (
        TRTLLMHAAttnBackend as extend_backend_cls,
    )

    runner = self.draft_model_runner
    return extend_backend_cls(runner, skip_prefill=False)
|
||||||
|
|
||||||
|
def _create_trtllm_mla_prefill_backend(self):
    """Construct the TRT-LLM MLA backend used for draft extend.

    Returns:
        A ``TRTLLMMLABackend`` built from ``self.draft_model_runner`` with
        ``skip_prefill=False``.

    Raises:
        ValueError: If ``global_server_args_dict["use_mla_backend"]`` is
            falsy.
    """
    mla_enabled = global_server_args_dict["use_mla_backend"]
    if not mla_enabled:
        raise ValueError(
            "trtllm_mla backend requires MLA model (use_mla_backend=True)."
        )

    from sglang.srt.layers.attention.trtllm_mla_backend import (
        TRTLLMMLABackend as extend_backend_cls,
    )

    runner = self.draft_model_runner
    return extend_backend_cls(runner, skip_prefill=False)
|
||||||
|
|
||||||
def init_cuda_graphs(self):
|
def init_cuda_graphs(self):
|
||||||
"""Capture cuda graphs."""
|
"""Capture cuda graphs."""
|
||||||
self.cuda_graph_runner = None
|
self.cuda_graph_runner = None
|
||||||
|
|||||||
Reference in New Issue
Block a user