Add draft extend CUDA graph for flashinfer backend (#6805)

This commit is contained in:
Ke Bao
2025-06-02 16:51:26 +08:00
committed by GitHub
parent 55444ed667
commit a2cb5913a0
5 changed files with 170 additions and 3 deletions

View File

@@ -156,6 +156,7 @@ class EAGLEWorker(TpModelWorker):
if self.server_args.attention_backend == "flashinfer":
if not global_server_args_dict["use_mla_backend"]:
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferAttnBackend,
FlashInferMultiStepDraftBackend,
)
@@ -164,8 +165,13 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = FlashInferAttnBackend(
self.draft_model_runner,
skip_prefill=False,
)
else:
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
FlashInferMLAMultiStepDraftBackend,
)
@@ -174,7 +180,10 @@ class EAGLEWorker(TpModelWorker):
self.topk,
self.speculative_num_steps,
)
self.draft_extend_attn_backend = None
self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
self.draft_model_runner,
skip_prefill=False,
)
self.padded_static_len = self.speculative_num_steps + 1
self.has_prefill_wrapper_verify = True
elif self.server_args.attention_backend == "triton":