feat: mtp support dp-attention (#6081)
Co-authored-by: austindeng <austindeng@tencent.com> Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com> Co-authored-by: Qiaolin Yu <liin1211@outlook.com> Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -324,7 +324,10 @@ class AiterAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
def init_cuda_graph_state(
|
||||
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
|
||||
self,
|
||||
max_bs: int,
|
||||
max_num_tokens: int,
|
||||
kv_indices_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
|
||||
if kv_indices_buf is None:
|
||||
@@ -338,7 +341,7 @@ class AiterAttnBackend(AttentionBackend):
|
||||
|
||||
if not self.skip_prefill:
|
||||
self.cuda_graph_custom_mask = torch.zeros(
|
||||
(max_bs * self.max_context_len),
|
||||
(max_num_tokens * self.max_context_len),
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
@@ -19,7 +19,7 @@ class AttentionBackend(ABC):
|
||||
"""Init the metadata for a forward pass."""
|
||||
raise NotImplementedError()
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
"""Init the global shared states for cuda graph."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -122,6 +122,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
|
||||
def init_cuda_graph_state(
|
||||
self,
|
||||
max_bs: int,
|
||||
max_num_tokens: int,
|
||||
block_kv_indices: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if block_kv_indices is None:
|
||||
|
||||
@@ -1120,7 +1120,7 @@ class FlashAttentionBackend(AttentionBackend):
|
||||
|
||||
return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
"""Initialize CUDA graph state for the attention backend.
|
||||
|
||||
Args:
|
||||
@@ -1999,9 +1999,9 @@ class FlashAttentionMultiStepBackend:
|
||||
for i in range(self.speculative_num_steps - 1):
|
||||
self.attn_backends[i].init_forward_metadata(forward_batch)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends[i].init_cuda_graph_state(max_bs)
|
||||
self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
|
||||
@@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
def init_cuda_graph_state(
|
||||
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
|
||||
self,
|
||||
max_bs: int,
|
||||
max_num_tokens: int,
|
||||
kv_indices_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if kv_indices_buf is None:
|
||||
cuda_graph_kv_indices = torch.zeros(
|
||||
(max_bs * self.max_context_len,),
|
||||
(max_num_tokens * self.max_context_len,),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
@@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
|
||||
if not self.skip_prefill:
|
||||
self.cuda_graph_custom_mask = torch.zeros(
|
||||
(max_bs * self.max_context_len),
|
||||
(max_num_tokens * self.max_context_len),
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
)
|
||||
@@ -1096,7 +1099,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
|
||||
self.common_template(forward_batch, kv_indices, call_fn)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
(self.speculative_num_steps, max_bs * self.max_context_len),
|
||||
dtype=torch.int32,
|
||||
@@ -1105,7 +1108,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends[i].init_cuda_graph_state(
|
||||
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
||||
|
||||
@@ -199,7 +199,10 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
def init_cuda_graph_state(
|
||||
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
|
||||
self,
|
||||
max_bs: int,
|
||||
max_num_tokens: int,
|
||||
kv_indices_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if kv_indices_buf is None:
|
||||
cuda_graph_kv_indices = torch.zeros(
|
||||
@@ -852,7 +855,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
|
||||
self.common_template(forward_batch, kv_indices, call_fn)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
(self.speculative_num_steps, max_bs * self.max_context_len),
|
||||
dtype=torch.int32,
|
||||
@@ -861,7 +864,7 @@ class FlashInferMLAMultiStepDraftBackend:
|
||||
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends[i].init_cuda_graph_state(
|
||||
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
||||
|
||||
@@ -148,6 +148,7 @@ class FlashMLABackend(FlashInferMLAAttnBackend):
|
||||
def init_cuda_graph_state(
|
||||
self,
|
||||
max_bs: int,
|
||||
max_num_tokens: int,
|
||||
block_kv_indices: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if block_kv_indices is None:
|
||||
@@ -502,9 +503,11 @@ class FlashMLAMultiStepDraftBackend:
|
||||
|
||||
self.common_template(forward_batch, call_fn)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends[i].init_cuda_graph_state(max_bs, block_kv_indices=None)
|
||||
self.attn_backends[i].init_cuda_graph_state(
|
||||
max_bs, max_num_tokens, block_kv_indices=None
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
||||
def call_fn(i, forward_batch):
|
||||
|
||||
@@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend):
|
||||
if forward_batch_child.batch_size > 0:
|
||||
child.init_forward_metadata(forward_batch=forward_batch_child)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
self.primary.init_cuda_graph_state(max_bs=max_bs)
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
|
||||
for item in self.children:
|
||||
# TODO for children, maybe can provide *smaller* max_bs to optimize
|
||||
item.init_cuda_graph_state(max_bs=max_bs)
|
||||
item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
self,
|
||||
|
||||
@@ -261,6 +261,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
num_kv_splits = None
|
||||
attn_logits = None
|
||||
attn_lse = None
|
||||
|
||||
elif forward_batch.forward_mode.is_draft_extend():
|
||||
kv_indices, kv_indptr, qo_indptr, custom_mask = (
|
||||
spec_info.generate_attn_arg_prefill(
|
||||
@@ -335,24 +336,27 @@ class TritonAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
def init_cuda_graph_state(
|
||||
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
|
||||
self,
|
||||
max_bs: int,
|
||||
max_num_tokens: int,
|
||||
kv_indices_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
self.cuda_graph_attn_logits = torch.zeros(
|
||||
(max_bs, self.num_head, self.max_kv_splits, self.v_head_dim),
|
||||
(max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.cuda_graph_attn_lse = torch.zeros(
|
||||
(max_bs, self.num_head, self.max_kv_splits),
|
||||
(max_num_tokens, self.num_head, self.max_kv_splits),
|
||||
dtype=torch.float32,
|
||||
device=self.device,
|
||||
)
|
||||
self.cuda_graph_num_kv_splits = torch.full(
|
||||
(max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
|
||||
(max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
|
||||
)
|
||||
if kv_indices_buf is None:
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
(max_bs * self.max_context_len),
|
||||
(max_num_tokens * self.max_context_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
@@ -361,7 +365,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
|
||||
if not self.skip_prefill:
|
||||
self.cuda_graph_custom_mask = torch.zeros(
|
||||
(max_bs * self.max_context_len),
|
||||
(max_num_tokens * self.max_context_len),
|
||||
dtype=torch.uint8,
|
||||
device=self.device,
|
||||
)
|
||||
@@ -369,7 +373,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
if self.sliding_window_size is not None and self.sliding_window_size > 0:
|
||||
if kv_indices_buf is None:
|
||||
self.cuda_graph_window_kv_indices = torch.zeros(
|
||||
(max_bs * self.sliding_window_size),
|
||||
(max_num_tokens * self.sliding_window_size),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
@@ -377,7 +381,10 @@ class TritonAttnBackend(AttentionBackend):
|
||||
self.cuda_graph_window_kv_indices = torch.zeros_like(kv_indices_buf)
|
||||
|
||||
self.cuda_graph_window_num_kv_splits = torch.full(
|
||||
(max_bs,), self.max_kv_splits, dtype=torch.int32, device=self.device
|
||||
(max_num_tokens,),
|
||||
self.max_kv_splits,
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(
|
||||
@@ -458,6 +465,7 @@ class TritonAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
custom_mask = self.cuda_graph_custom_mask
|
||||
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
|
||||
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
|
||||
mask_indptr = self.mask_indptr[: bs + 1]
|
||||
mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
|
||||
@@ -821,15 +829,15 @@ class TritonMultiStepDraftBackend:
|
||||
|
||||
self.common_template(forward_batch, kv_indices, call_fn)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
(self.speculative_num_steps, max_bs * self.max_context_len),
|
||||
(self.speculative_num_steps, max_num_tokens * self.max_context_len),
|
||||
dtype=torch.int32,
|
||||
device=self.device,
|
||||
)
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends[i].init_cuda_graph_state(
|
||||
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
||||
|
||||
Reference in New Issue
Block a user