Support Eagle cuda graph for Triton backend (#3500)
@@ -38,6 +38,8 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
+        self.skip_prefill = skip_prefill
+
         max_bs = model_runner.req_to_token_pool.size
 
         if kv_indptr_buf is None:
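The two added lines store a `skip_prefill` flag that downstream hunks use to make this backend decode-only. A sketch of the construction pattern this enables, assuming the constructor accepts `skip_prefill` and `kv_indptr_buf` as keyword arguments (which the stored fields suggest); `build_draft_backends` and its parameters are illustrative names, not part of this diff:

```python
def build_draft_backends(model_runner, kv_indptr_buffer, speculative_num_steps):
    # Illustrative sketch: a multi-step speculative wrapper constructs one
    # decode-only TritonAttnBackend per draft step. skip_prefill=True means
    # the prefill-side buffers are never allocated; kv_indptr_buf hands each
    # backend a preallocated row instead of letting it allocate its own.
    return [
        TritonAttnBackend(
            model_runner,
            skip_prefill=True,
            kv_indptr_buf=kv_indptr_buffer[i],
        )
        for i in range(speculative_num_steps)
    ]
```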
@@ -48,13 +50,15 @@ class TritonAttnBackend(AttentionBackend):
             self.kv_indptr = kv_indptr_buf
 
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
-        self.qo_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
-        )
-
-        self.mask_indptr = torch.zeros(
-            (max_bs + 1,), dtype=torch.int64, device=model_runner.device
-        )
+        if not self.skip_prefill:
+            self.qo_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+
+            self.mask_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int64, device=model_runner.device
+            )
+
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
 
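These `*_indptr` buffers all follow the CSR-style offsets convention used throughout the backend: entry 0 is 0, entry i+1 is the running total of per-request lengths, so request i owns the flat slice `[indptr[i], indptr[i+1])`. A small self-contained illustration with made-up lengths:

```python
import torch

# CSR-style offsets: request i owns flat slice [indptr[i], indptr[i+1]).
seq_lens = torch.tensor([3, 5, 2])          # tokens held by each request
indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
indptr[1:] = torch.cumsum(seq_lens, dim=0)  # same pattern as kv_indptr above
print(indptr)                               # -> [0, 3, 8, 10]
```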
@@ -196,22 +200,29 @@ class TritonAttnBackend(AttentionBackend):
             mask_indptr,
         )
 
-    def init_cuda_graph_state(self, max_bs: int):
-        self.cuda_graph_max_total_num_tokens = max_bs * self.max_context_len
-
-        self.cuda_graph_start_loc = torch.zeros(
-            (max_bs,), dtype=torch.int32, device=self.device
-        )
+    def init_cuda_graph_state(
+        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
+    ):
         self.cuda_graph_attn_logits = torch.zeros(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
             device=self.device,
         )
-        self.cuda_graph_kv_indices = torch.zeros(
-            (max_bs * self.max_context_len),
-            dtype=torch.int32,
-            device=self.device,
-        )
+        if kv_indices_buf is None:
+            self.cuda_graph_kv_indices = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_kv_indices = kv_indices_buf
+
+        if not self.skip_prefill:
+            self.cuda_graph_custom_mask = torch.zeros(
+                (max_bs * self.max_context_len),
+                dtype=torch.uint8,
+                device=self.device,
+            )
 
     def init_forward_metadata_capture_cuda_graph(
         self,
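The new `kv_indices_buf` parameter exists so a caller can make one large allocation and hand each backend a view of it; the multi-step draft backend at the end of this diff does exactly that. A sketch of the sharing pattern (the helper name and parameters here are illustrative):

```python
import torch

def init_shared_kv_indices(attn_backends, num_steps, max_bs, max_context_len, device):
    # One big allocation; each per-step backend adopts a row as its
    # cuda_graph_kv_indices, so the pointer captured by each CUDA graph
    # stays valid across every replay.
    kv_indices = torch.zeros(
        (num_steps, max_bs * max_context_len), dtype=torch.int32, device=device
    )
    for i, backend in enumerate(attn_backends):
        backend.init_cuda_graph_state(max_bs, kv_indices_buf=kv_indices[i])
    return kv_indices
```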
@@ -224,31 +235,71 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         assert encoder_lens is None, "Not supported"
-        assert forward_mode.is_decode(), "Not supported"
-        assert spec_info is None, "Not supported"
-
-        kv_indptr = self.kv_indptr
-        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = self.cuda_graph_kv_indices
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices,
-            seq_lens,
-            kv_indptr,
-            None,
-            kv_indices,
-            self.req_to_token.stride(0),
-        )
+
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                kv_indptr = self.kv_indptr
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                kv_indices = self.cuda_graph_kv_indices
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
+
+            attn_logits = self.cuda_graph_attn_logits
+            max_extend_len = None
+            qo_indptr = None
+            custom_mask = None
+            mask_indptr = None
+        elif forward_mode.is_target_verify():
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            custom_mask = self.cuda_graph_custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+            max_extend_len = self.num_draft_tokens
+            attn_logits = None
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
+            )
 
         self.forward_metadata = (
-            self.cuda_graph_attn_logits,
-            None,
+            attn_logits,
+            max_extend_len,
             kv_indptr,
             kv_indices,
-            None,
-            None,
-            None,
+            qo_indptr,
+            custom_mask,
+            mask_indptr,
         )
 
     def init_forward_metadata_replay_cuda_graph(
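In the target-verify branch every request contributes exactly `num_draft_tokens` query tokens, so `qo_indptr` is a fixed-stride `arange` rather than a cumsum, and each request's slice of the custom mask covers `num_draft_tokens` query rows against `seq_len + num_draft_tokens` keys (the prefix plus the draft tokens themselves). A worked example with made-up sizes:

```python
import torch

bs, num_draft_tokens = 2, 4
seq_lens = torch.tensor([10, 7])

# Fixed-stride query offsets: every request owns exactly 4 query rows.
qo_indptr = torch.arange(0, (1 + bs) * num_draft_tokens, step=num_draft_tokens)
print(qo_indptr)        # tensor([0, 4, 8])

# Mask entries per request: 4 rows x (seq_len + 4) keys.
seq_mask_len = num_draft_tokens * (seq_lens + num_draft_tokens)
print(seq_mask_len)     # tensor([56, 44])

mask_indptr = torch.zeros(bs + 1, dtype=torch.int64)
mask_indptr[1:] = torch.cumsum(seq_mask_len, dim=0)
print(mask_indptr)      # tensor([  0,  56, 100])
```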
@@ -262,22 +313,57 @@ class TritonAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
-        self.cuda_graph_start_loc.zero_()
-        self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
-
-        kv_indptr = self.kv_indptr
-        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
-        kv_indptr = kv_indptr[: bs + 1]
-        kv_indices = self.cuda_graph_kv_indices
-        create_flashinfer_kv_indices_triton[(bs,)](
-            self.req_to_token,
-            req_pool_indices[:bs],
-            seq_lens[:bs],
-            kv_indptr,
-            None,
-            kv_indices,
-            self.req_to_token.stride(0),
-        )
+        if forward_mode.is_decode_or_idle():
+            # Update kv_indptr, kv_indices
+            kv_indptr = self.kv_indptr
+            kv_indices = self.cuda_graph_kv_indices
+            if spec_info is None:
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+                kv_indptr = kv_indptr[: bs + 1]
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices[:bs],
+                    seq_lens[:bs],
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+            else:
+                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
+                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+        elif forward_mode.is_target_verify():
+            # Update qo_indptr, kv_indptr, kv_indices, custom_mask, mask_indptr
+            bs = len(req_pool_indices)
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                (1 + bs) * self.num_draft_tokens,
+                step=self.num_draft_tokens,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            custom_mask = self.cuda_graph_custom_mask
+            custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
+            seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
+            mask_indptr = self.mask_indptr[: bs + 1]
+            mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+        else:
+            raise ValueError(
+                f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
+            )
 
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
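A captured CUDA graph replays fixed kernel arguments, so replay must write fresh metadata into the buffers whose addresses were captured; rebinding a Python name to a new tensor would leave the graph reading stale memory. That is why the spec-info branches above copy with sliced assignment instead of reassigning. A minimal illustration of the rule (a CPU tensor stands in for the captured buffer):

```python
import torch

captured_buf = torch.zeros(8, dtype=torch.int32)  # stands in for cuda_graph_kv_indices

def update_for_replay(new_indices: torch.Tensor):
    # Correct: in-place copy into the captured storage.
    captured_buf[: new_indices.shape[0]] = new_indices
    # Wrong for graph replay: `captured_buf = new_indices` would only rebind
    # the local name; the graph would keep reading the old allocation.

update_for_replay(torch.tensor([5, 6, 7], dtype=torch.int32))
print(captured_buf)  # tensor([5, 6, 7, 0, 0, 0, 0, 0], dtype=torch.int32)
```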
@@ -407,6 +493,7 @@ class TritonMultiStepDraftBackend:
             )
         )
         self.max_context_len = self.attn_backends[0].max_context_len
+        self.device = model_runner.device
         # Cached variables for generate_draft_decode_kv_indices
         self.pool_len = model_runner.req_to_token_pool.req_to_token.shape[1]
 
@@ -450,7 +537,7 @@ class TritonMultiStepDraftBackend:
                 forward_batch.batch_size * self.topk * self.max_context_len,
             ),
             dtype=torch.int32,
-            device="cuda",
+            device=self.device,
         )
 
         def call_fn(i, forward_batch):
@@ -468,7 +555,7 @@ class TritonMultiStepDraftBackend:
         self.cuda_graph_kv_indices = torch.zeros(
             (self.speculative_num_steps, max_bs * self.max_context_len),
             dtype=torch.int32,
-            device="cuda",
+            device=self.device,
        )
        for i in range(self.speculative_num_steps):
            self.attn_backends[i].init_cuda_graph_state(
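The last two hunks replace the hardcoded `device="cuda"` string with the `self.device` cached from `model_runner.device` earlier in the diff, so the draft backend allocates its buffers on whatever device the model actually runs on (for example a non-default GPU). A sketch of the cached-device pattern with illustrative names:

```python
import torch

class DraftBufferOwner:
    def __init__(self, device: torch.device):
        # Cache the runner's device once instead of assuming "cuda" (cuda:0).
        self.device = device

    def alloc_kv_indices(self, n: int) -> torch.Tensor:
        return torch.zeros(n, dtype=torch.int32, device=self.device)

owner = DraftBufferOwner(torch.device("cpu"))  # e.g. model_runner.device
print(owner.alloc_kv_indices(4))               # tensor([0, 0, 0, 0], dtype=torch.int32)
```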