Add draft extend CUDA graph for flashinfer backend (#6805)
This commit is contained in:
@@ -156,6 +156,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
if self.server_args.attention_backend == "flashinfer":
|
||||
if not global_server_args_dict["use_mla_backend"]:
|
||||
from sglang.srt.layers.attention.flashinfer_backend import (
|
||||
FlashInferAttnBackend,
|
||||
FlashInferMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
@@ -164,8 +165,13 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = FlashInferAttnBackend(
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
else:
|
||||
from sglang.srt.layers.attention.flashinfer_mla_backend import (
|
||||
FlashInferMLAAttnBackend,
|
||||
FlashInferMLAMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
@@ -174,7 +180,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = None
|
||||
self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = True
|
||||
elif self.server_args.attention_backend == "triton":
|
||||
|
||||
Reference in New Issue
Block a user