Add draft-extend CUDA graph support for the flashinfer backend (#6805)
@@ -358,6 +358,35 @@ class FlashInferAttnBackend(AttentionBackend):
             )
             self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
             self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
+        elif forward_mode.is_draft_extend():
+            prefill_wrappers = []
+            for i in range(self.num_wrappers):
+                prefill_wrappers.append(
+                    BatchPrefillWithPagedKVCacheWrapper(
+                        self.workspace_buffer,
+                        "NHD",
+                        backend="fa2",
+                        use_cuda_graph=True,
+                        qo_indptr_buf=self.cuda_graph_qo_indptr[i][: bs + 1],
+                        paged_kv_indptr_buf=self.kv_indptr[i][: bs + 1],
+                        paged_kv_indices_buf=self.cuda_graph_kv_indices[i],
+                        paged_kv_last_page_len_buf=self.kv_last_page_len[:bs],
+                    )
+                )
+
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_prefill.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=prefill_wrappers,
+                use_ragged=False,
+                encoder_lens=encoder_lens,
+                spec_info=spec_info,
+            )
+            self.prefill_cuda_graph_metadata[bs] = prefill_wrappers
+            self.forward_metadata = PrefillMetadata(prefill_wrappers, False, False)
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")
 
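A note on the capture contract this hunk relies on: with use_cuda_graph=True, a FlashInfer BatchPrefillWithPagedKVCacheWrapper reads its qo_indptr/paged_kv_* metadata from the fixed buffers passed at construction, so plan() can be re-run eagerly between replays while only run() is captured inside the graph. A minimal standalone sketch of that pattern, outside sglang (head counts, buffer sizes, and dtypes are illustrative, not taken from this commit):

import torch
import flashinfer

bs, page_size, num_qo_heads, num_kv_heads, head_dim = 4, 1, 32, 8, 128
max_pages = 1024

workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
# Fixed buffers: the captured graph will read these exact storages.
qo_indptr = torch.arange(bs + 1, dtype=torch.int32, device="cuda")  # 1 token/req
kv_indptr = torch.arange(bs + 1, dtype=torch.int32, device="cuda")  # 1 page/req
kv_indices = torch.arange(max_pages, dtype=torch.int32, device="cuda")
kv_last_page_len = torch.ones(bs, dtype=torch.int32, device="cuda")

wrapper = flashinfer.BatchPrefillWithPagedKVCacheWrapper(
    workspace,
    "NHD",
    backend="fa2",
    use_cuda_graph=True,
    qo_indptr_buf=qo_indptr,
    paged_kv_indptr_buf=kv_indptr,
    paged_kv_indices_buf=kv_indices,
    paged_kv_last_page_len_buf=kv_last_page_len,
)

# plan() runs eagerly (host-side scheduling) before capture and again before
# every replay, after the buffers above have been refreshed in place.
wrapper.plan(
    qo_indptr, kv_indptr, kv_indices[:bs], kv_last_page_len,
    num_qo_heads, num_kv_heads, head_dim, page_size,
    causal=True, q_data_type=torch.float16,
)

q = torch.randn(bs, num_qo_heads, head_dim, dtype=torch.float16, device="cuda")
kv_cache = torch.randn(
    max_pages, 2, page_size, num_kv_heads, head_dim,
    dtype=torch.float16, device="cuda",
)
wrapper.run(q, kv_cache)  # warmup; only run() goes inside the graph
torch.cuda.synchronize()
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    out = wrapper.run(q, kv_cache)
g.replay()

Between replays the backend refreshes the registered buffers in place and calls plan() again; the captured run() then sees the new metadata without re-capturing.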
@@ -392,6 +421,17 @@ class FlashInferAttnBackend(AttentionBackend):
                 encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
                 spec_info=spec_info,
             )
+        elif forward_mode.is_draft_extend():
+            self.indices_updater_prefill.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrappers=self.prefill_cuda_graph_metadata[bs],
+                use_ragged=False,
+                encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+                spec_info=spec_info,
+            )
         else:
             raise ValueError("Invalid forward mode")
 
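The replay path reuses the exact wrapper objects cached at capture time (prefill_cuda_graph_metadata[bs]) and slices its inputs to bs, since the graph runner pads the batch to a captured size. A small sketch of the in-place refresh idiom this depends on (buffer and function names are hypothetical):

import torch

max_bs = 8
# The captured graph reads this exact storage; refresh contents, never rebind.
seq_lens_buf = torch.zeros(max_bs, dtype=torch.int32, device="cuda")

def refresh_for_replay(seq_lens: torch.Tensor, bs: int):
    # Real requests occupy the first bs slots; the rest keep padding values.
    seq_lens_buf[:bs].copy_(seq_lens[:bs])
    # Afterwards: re-plan the cached wrappers for this bs, then graph.replay().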
@@ -278,6 +278,28 @@ class FlashInferMLAAttnBackend(AttentionBackend):
             )
             self.prefill_cuda_graph_metadata[bs] = verify_wrapper
             self.forward_metadata = PrefillMetadata(verify_wrapper, False)
+        elif forward_mode.is_draft_extend():
+            draft_extend_wrapper = BatchMLAPagedAttentionWrapper(
+                self.workspace_buffer,
+                use_cuda_graph=True,
+                qo_indptr=self.cuda_graph_qo_indptr[: bs + 1],
+                kv_indptr=self.cuda_graph_kv_indptr[: bs + 1],
+                kv_indices=self.cuda_graph_kv_indices,
+                kv_len_arr=self.cuda_graph_kv_lens[:bs],
+                backend="auto",
+            )
+            seq_lens_sum = seq_lens.sum().item()
+            self.indices_updater_prefill.update(
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrapper_paged=draft_extend_wrapper,
+                use_ragged=False,
+                spec_info=spec_info,
+            )
+            self.prefill_cuda_graph_metadata[bs] = draft_extend_wrapper
+            self.forward_metadata = PrefillMetadata(draft_extend_wrapper, False)
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")
 
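Unlike the standard path, which builds one wrapper per KV group, the MLA path caches a single BatchMLAPagedAttentionWrapper per batch size. A hedged construction sketch with illustrative buffer sizes (plan()/run() usage mirrors the non-MLA example above):

import torch
from flashinfer.mla import BatchMLAPagedAttentionWrapper

bs, max_pages = 4, 1024
workspace = torch.empty(128 * 1024 * 1024, dtype=torch.uint8, device="cuda")
qo_indptr = torch.zeros(bs + 1, dtype=torch.int32, device="cuda")
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32, device="cuda")
kv_indices = torch.zeros(max_pages, dtype=torch.int32, device="cuda")
kv_lens = torch.zeros(bs, dtype=torch.int32, device="cuda")

draft_extend_wrapper = BatchMLAPagedAttentionWrapper(
    workspace,
    use_cuda_graph=True,
    qo_indptr=qo_indptr,
    kv_indptr=kv_indptr,
    kv_indices=kv_indices,
    kv_len_arr=kv_lens,
    backend="auto",
)
# As in the non-MLA path, plan() is re-run eagerly with real lengths before
# each replay, while run() executes inside the captured graph.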
@@ -325,6 +347,16 @@ class FlashInferMLAAttnBackend(AttentionBackend):
                 use_ragged=False,
                 spec_info=spec_info,
             )
+        elif forward_mode.is_draft_extend():
+            self.indices_updater_prefill.update(
+                req_pool_indices[:bs],
+                seq_lens[:bs],
+                seq_lens_sum,
+                prefix_lens=None,
+                prefill_wrapper_paged=self.prefill_cuda_graph_metadata[bs],
+                use_ragged=False,
+                spec_info=spec_info,
+            )
         else:
             raise ValueError(f"Invalid forward mode: {forward_mode=}")
 
@@ -80,7 +80,9 @@ class EAGLEDraftExtendCudaGraphRunner:
 
         self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
         self.extend_seq_lens = torch.ones((self.max_bs,), dtype=torch.int32)
-        self.accept_length = torch.ones((self.max_bs,), dtype=torch.int32)
+        self.accept_length = (
+            torch.ones((self.max_bs,), dtype=torch.int32) * self.num_tokens_per_bs
+        )
 
         # Capture
         try:
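Context for the accept_length change above: the runner captures graphs with dummy inputs, and the draft-extend pass processes num_tokens_per_bs tokens per request (typically speculative_num_steps + 1), so the capture-time default must claim the full per-request token budget rather than 1. Illustrative arithmetic:

import torch

max_bs, num_tokens_per_bs = 8, 4  # e.g. 3 speculative steps + 1
accept_length = torch.ones(max_bs, dtype=torch.int32) * num_tokens_per_bs
num_tokens = max_bs * num_tokens_per_bs  # token count the graph is captured with
assert int(accept_length.sum()) == num_tokens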
@@ -156,6 +156,7 @@ class EAGLEWorker(TpModelWorker):
         if self.server_args.attention_backend == "flashinfer":
             if not global_server_args_dict["use_mla_backend"]:
                 from sglang.srt.layers.attention.flashinfer_backend import (
+                    FlashInferAttnBackend,
                     FlashInferMultiStepDraftBackend,
                 )
 
@@ -164,8 +165,13 @@ class EAGLEWorker(TpModelWorker):
                     self.topk,
                     self.speculative_num_steps,
                 )
+                self.draft_extend_attn_backend = FlashInferAttnBackend(
+                    self.draft_model_runner,
+                    skip_prefill=False,
+                )
             else:
                 from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                    FlashInferMLAAttnBackend,
                     FlashInferMLAMultiStepDraftBackend,
                 )
 
@@ -174,7 +180,10 @@ class EAGLEWorker(TpModelWorker):
                     self.topk,
                     self.speculative_num_steps,
                 )
-                self.draft_extend_attn_backend = None
+                self.draft_extend_attn_backend = FlashInferMLAAttnBackend(
+                    self.draft_model_runner,
+                    skip_prefill=False,
+                )
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "triton":
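Net effect of the two EAGLEWorker hunks: the MLA path no longer leaves draft_extend_attn_backend as None, and both paths now construct a dedicated backend with skip_prefill=False so draft-extend prefill metadata can be built for CUDA graphs. A condensed, hypothetical paraphrase of the resulting selection logic (the helper name is invented for illustration; only the flashinfer branch is shown):

def make_draft_extend_backend(self, use_mla_backend: bool):
    if self.server_args.attention_backend != "flashinfer":
        raise NotImplementedError  # triton branch omitted in this sketch
    if not use_mla_backend:
        from sglang.srt.layers.attention.flashinfer_backend import (
            FlashInferAttnBackend,
        )
        return FlashInferAttnBackend(self.draft_model_runner, skip_prefill=False)
    from sglang.srt.layers.attention.flashinfer_mla_backend import (
        FlashInferMLAAttnBackend,
    )
    return FlashInferMLAAttnBackend(self.draft_model_runner, skip_prefill=False)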