Add draft extend CUDA graph for Triton backend (#6705)
This commit is contained in:
@@ -128,6 +128,7 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
)
|
)
|
||||||
|
|
||||||
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
|
||||||
|
self.speculative_num_steps = model_runner.server_args.speculative_num_steps
|
||||||
|
|
||||||
self.num_head = (
|
self.num_head = (
|
||||||
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
model_runner.model_config.num_attention_heads // get_attention_tp_size()
|
||||||
@@ -424,6 +425,34 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
num_kv_splits = None
|
num_kv_splits = None
|
||||||
attn_logits = None
|
attn_logits = None
|
||||||
attn_lse = None
|
attn_lse = None
|
||||||
|
elif forward_mode.is_draft_extend():
|
||||||
|
num_tokens_per_bs = self.speculative_num_steps + 1
|
||||||
|
qo_indptr = self.qo_indptr[: bs + 1]
|
||||||
|
qo_indptr[: bs + 1] = torch.arange(
|
||||||
|
0,
|
||||||
|
bs * num_tokens_per_bs + 1,
|
||||||
|
step=num_tokens_per_bs,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=self.device,
|
||||||
|
)
|
||||||
|
kv_indptr = self.kv_indptr[: bs + 1]
|
||||||
|
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||||
|
kv_indices = self.cuda_graph_kv_indices
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
kv_indptr,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
)
|
||||||
|
custom_mask = None
|
||||||
|
mask_indptr = None
|
||||||
|
max_extend_len = num_tokens_per_bs
|
||||||
|
num_kv_splits = None
|
||||||
|
attn_logits = None
|
||||||
|
attn_lse = None
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
|
f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
|
||||||
@@ -504,6 +533,23 @@ class TritonAttnBackend(AttentionBackend):
|
|||||||
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
|
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
|
||||||
mask_indptr = self.mask_indptr[: bs + 1]
|
mask_indptr = self.mask_indptr[: bs + 1]
|
||||||
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
|
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
|
||||||
|
elif forward_mode.is_draft_extend():
|
||||||
|
seq_lens = seq_lens[:bs]
|
||||||
|
accept_lens = spec_info.accept_length[:bs]
|
||||||
|
qo_indptr = self.qo_indptr[: bs + 1]
|
||||||
|
qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
|
||||||
|
kv_indptr = self.kv_indptr[: bs + 1]
|
||||||
|
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
|
||||||
|
kv_indices = self.cuda_graph_kv_indices
|
||||||
|
create_flashinfer_kv_indices_triton[(bs,)](
|
||||||
|
self.req_to_token,
|
||||||
|
req_pool_indices,
|
||||||
|
seq_lens,
|
||||||
|
kv_indptr,
|
||||||
|
None,
|
||||||
|
kv_indices,
|
||||||
|
self.req_to_token.stride(0),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
|
f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
|
||||||
|
|||||||
@@ -179,6 +179,7 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.has_prefill_wrapper_verify = True
|
self.has_prefill_wrapper_verify = True
|
||||||
elif self.server_args.attention_backend == "triton":
|
elif self.server_args.attention_backend == "triton":
|
||||||
from sglang.srt.layers.attention.triton_backend import (
|
from sglang.srt.layers.attention.triton_backend import (
|
||||||
|
TritonAttnBackend,
|
||||||
TritonMultiStepDraftBackend,
|
TritonMultiStepDraftBackend,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -187,7 +188,10 @@ class EAGLEWorker(TpModelWorker):
|
|||||||
self.topk,
|
self.topk,
|
||||||
self.speculative_num_steps,
|
self.speculative_num_steps,
|
||||||
)
|
)
|
||||||
self.draft_extend_attn_backend = None
|
self.draft_extend_attn_backend = TritonAttnBackend(
|
||||||
|
self.draft_model_runner,
|
||||||
|
skip_prefill=False,
|
||||||
|
)
|
||||||
self.padded_static_len = self.speculative_num_steps + 1
|
self.padded_static_len = self.speculative_num_steps + 1
|
||||||
self.has_prefill_wrapper_verify = False
|
self.has_prefill_wrapper_verify = False
|
||||||
elif self.server_args.attention_backend == "fa3":
|
elif self.server_args.attention_backend == "fa3":
|
||||||
|
|||||||
Reference in New Issue
Block a user