Add draft extend CUDA graph for Triton backend (#6705)
This commit is contained in:
@@ -128,6 +128,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
||||
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
||||
|
||||
self.num_head = (
|
||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||
@@ -424,6 +425,34 @@ class TritonAttnBackend(AttentionBackend):
|
||||
num_kv_splits = None
|
||||
attn_logits = None
|
||||
attn_lse = None
|
||||
elif forward_mode.is_draft_extend():
|
||||
num_tokens_per_bs = self.speculative_num_steps + 1
|
||||
qo_indptr = self.qo_indptr[: bs + 1]
|
||||
qo_indptr[: bs + 1] = torch.arange(
|
||||
0,
|
||||
bs * num_tokens_per_bs + 1,
|
||||
step=num_tokens_per_bs,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
kv_indptr = self.kv_indptr[: bs + 1]
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||
kv_indices = self.cuda_graph_kv_indices
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
custom_mask = None
|
||||
mask_indptr = None
|
||||
max_extend_len = num_tokens_per_bs
|
||||
num_kv_splits = None
|
||||
attn_logits = None
|
||||
attn_lse = None
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
|
||||
@@ -504,6 +533,23 @@ class TritonAttnBackend(AttentionBackend):
|
||||
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
|
||||
mask_indptr = self.mask_indptr[: bs + 1]
|
||||
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
|
||||
elif forward_mode.is_draft_extend():
|
||||
seq_lens = seq_lens[:bs]
|
||||
accept_lens = spec_info.accept_length[:bs]
|
||||
qo_indptr = self.qo_indptr[: bs + 1]
|
||||
qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
|
||||
kv_indptr = self.kv_indptr[: bs + 1]
|
||||
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||
kv_indices = self.cuda_graph_kv_indices
|
||||
create_flashinfer_kv_indices_triton[(bs,)](
|
||||
self.req_to_token,
|
||||
req_pool_indices,
|
||||
seq_lens,
|
||||
kv_indptr,
|
||||
None,
|
||||
kv_indices,
|
||||
self.req_to_token.stride(0),
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
|
||||
|
||||
@@ -179,6 +179,7 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.has_prefill_wrapper_verify = True
|
||||
elif self.server_args.attention_backend == "triton":
|
||||
from sglang.srt.layers.attention.triton_backend import (
|
||||
TritonAttnBackend,
|
||||
TritonMultiStepDraftBackend,
|
||||
)
|
||||
|
||||
@@ -187,7 +188,10 @@ class EAGLEWorker(TpModelWorker):
|
||||
self.topk,
|
||||
self.speculative_num_steps,
|
||||
)
|
||||
self.draft_extend_attn_backend = None
|
||||
self.draft_extend_attn_backend = TritonAttnBackend(
|
||||
self.draft_model_runner,
|
||||
skip_prefill=False,
|
||||
)
|
||||
self.padded_static_len = self.speculative_num_steps + 1
|
||||
self.has_prefill_wrapper_verify = False
|
||||
elif self.server_args.attention_backend == "fa3":
|
||||
|
||||
Reference in New Issue
Block a user