feat: mtp support dp-attention (#6081)

Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
u4lr451
2025-06-17 15:33:28 +08:00
committed by GitHub
parent 8a10c4c3d9
commit 10d60cd41b
22 changed files with 641 additions and 151 deletions

View File

@@ -324,7 +324,10 @@ class AiterAttnBackend(AttentionBackend):
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
self,
max_bs: int,
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
):
self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
if kv_indices_buf is None:
@@ -338,7 +341,7 @@ class AiterAttnBackend(AttentionBackend):
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
(max_num_tokens * self.max_context_len),
dtype=torch.uint8,
device=self.device,
)

View File

@@ -19,7 +19,7 @@ class AttentionBackend(ABC):
"""Init the metadata for a forward pass."""
raise NotImplementedError()
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
"""Init the global shared states for cuda graph."""
raise NotImplementedError()

View File

@@ -122,6 +122,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
def init_cuda_graph_state(
self,
max_bs: int,
max_num_tokens: int,
block_kv_indices: Optional[torch.Tensor] = None,
):
if block_kv_indices is None:

View File

@@ -1120,7 +1120,7 @@ class FlashAttentionBackend(AttentionBackend):
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
"""Initialize CUDA graph state for the attention backend.
Args:
@@ -1999,9 +1999,9 @@ class FlashAttentionMultiStepBackend:
for i in range(self.speculative_num_steps - 1):
self.attn_backends[i].init_forward_metadata(forward_batch)
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(max_bs)
self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
def init_forward_metadata_capture_cuda_graph(
self,

View File

@@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend):
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
self,
max_bs: int,
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
):
if kv_indices_buf is None:
cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len,),
(max_num_tokens * self.max_context_len,),
dtype=torch.int32,
device="cuda",
)
@@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend):
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
(max_num_tokens * self.max_context_len),
dtype=torch.uint8,
device="cuda",
)
@@ -1096,7 +1099,7 @@ class FlashInferMultiStepDraftBackend:
self.common_template(forward_batch, kv_indices, call_fn)
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len),
dtype=torch.int32,
@@ -1105,7 +1108,7 @@ class FlashInferMultiStepDraftBackend:
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):

View File

@@ -199,7 +199,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
self,
max_bs: int,
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
):
if kv_indices_buf is None:
cuda_graph_kv_indices = torch.zeros(
@@ -852,7 +855,7 @@ class FlashInferMLAMultiStepDraftBackend:
self.common_template(forward_batch, kv_indices, call_fn)
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len),
dtype=torch.int32,
@@ -861,7 +864,7 @@ class FlashInferMLAMultiStepDraftBackend:
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):

View File

@@ -148,6 +148,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
def init_cuda_graph_state(
self,
max_bs: int,
max_num_tokens: int,
block_kv_indices: Optional[torch.Tensor] = None,
):
if block_kv_indices is None:
@@ -502,9 +503,11 @@ class FlashMLAMultiStepDraftBackend:
self.common_template(forward_batch, call_fn)
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None)
self.attn_backends[i].init_cuda_graph_state(
max_bs, max_num_tokens, block_kv_indices=None
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
def call_fn(i, forward_batch):

View File

@@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend):
if forward_batch_child.batch_size > 0:
child.init_forward_metadata(forward_batch=forward_batch_child)
def init_cuda_graph_state(self, max_bs: int):
self.primary.init_cuda_graph_state(max_bs=max_bs)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
for item in self.children:
# TODO for children, maybe can provide *smaller* max_bs to optimize
item.init_cuda_graph_state(max_bs=max_bs)
item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
def init_forward_metadata_capture_cuda_graph(
self,

View File

@@ -261,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
num_kv_splits = None
attn_logits = None
attn_lse = None
elif forward_batch.forward_mode.is_draft_extend():
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
@@ -335,24 +336,27 @@ class TritonAttnBackend(AttentionBackend):
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
self,
max_bs: int,
max_num_tokens: int,
kv_indices_buf: Optional[torch.Tensor] = None,
):
self.cuda_graph_attn_logits = torch.zeros(
(max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
(max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
dtype=torch.float32,
device=self.device,
)
self.cuda_graph_attn_lse = torch.zeros(
(max_bs, self.num_head, self.max_kv_splits),
(max_num_tokens, self.num_head, self.max_kv_splits),
dtype=torch.float32,
device=self.device,
)
self.cuda_graph_num_kv_splits = torch.full(
(max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
(max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
)
if kv_indices_buf is None:
self.cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len),
(max_num_tokens * self.max_context_len),
dtype=torch.int32,
device=self.device,
)
@@ -361,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
if not self.skip_prefill:
self.cuda_graph_custom_mask = torch.zeros(
(max_bs * self.max_context_len),
(max_num_tokens * self.max_context_len),
dtype=torch.uint8,
device=self.device,
)
@@ -369,7 +373,7 @@ class TritonAttnBackend(AttentionBackend):
if self.sliding_window_size is not None and self.sliding_window_size > 0:
if kv_indices_buf is None:
self.cuda_graph_window_kv_indices = torch.zeros(
(max_bs * self.sliding_window_size),
(max_num_tokens * self.sliding_window_size),
dtype=torch.int32,
device=self.device,
)
@@ -377,7 +381,10 @@ class TritonAttnBackend(AttentionBackend):
self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)
self.cuda_graph_window_num_kv_splits = torch.full(
(max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
(max_num_tokens,),
self.max_kv_splits,
dtype=torch.int32,
device=self.device,
)
def init_forward_metadata_capture_cuda_graph(
@@ -458,6 +465,7 @@ class TritonAttnBackend(AttentionBackend):
)
custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
mask_indptr = self.mask_indptr[: bs + 1]
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
@@ -821,15 +829,15 @@ class TritonMultiStepDraftBackend:
self.common_template(forward_batch, kv_indices, call_fn)
def init_cuda_graph_state(self, max_bs: int):
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
self.cuda_graph_kv_indices = torch.zeros(
(self.speculative_num_steps, max_bs * self.max_context_len),
(self.speculative_num_steps, max_num_tokens * self.max_context_len),
dtype=torch.int32,
device=self.device,
)
for i in range(self.speculative_num_steps):
self.attn_backends[i].init_cuda_graph_state(
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
)
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):