feat: mtp support dp-attention (#6081)
Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
@@ -262,11 +262,14 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
)
|
||||
|
||||
def init_cuda_graph_state(
|
||||
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
|
||||
self,
|
||||
max_bs: int,
|
||||
max_num_tokens: int,
|
||||
kv_indices_buf: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if kv_indices_buf is None:
|
||||
cuda_graph_kv_indices = torch.zeros(
|
||||
(max_bs * self.max_context_len,),
|
||||
(max_num_tokens * self.max_context_len,),
|
||||
dtype=torch.int32,
|
||||
device="cuda",
|
||||
)
|
||||
@@ -285,7 +288,7 @@ class FlashInferAttnBackend(AttentionBackend):
|
||||
|
||||
if not self.skip_prefill:
|
||||
self.cuda_graph_custom_mask = torch.zeros(
|
||||
(max_bs * self.max_context_len),
|
||||
(max_num_tokens * self.max_context_len),
|
||||
dtype=torch.uint8,
|
||||
device="cuda",
|
||||
)
|
||||
@@ -1096,7 +1099,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
|
||||
self.common_template(forward_batch, kv_indices, call_fn)
|
||||
|
||||
def init_cuda_graph_state(self, max_bs: int):
|
||||
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
|
||||
self.cuda_graph_kv_indices = torch.zeros(
|
||||
(self.speculative_num_steps, max_bs * self.max_context_len),
|
||||
dtype=torch.int32,
|
||||
@@ -1105,7 +1108,7 @@ class FlashInferMultiStepDraftBackend:
|
||||
|
||||
for i in range(self.speculative_num_steps):
|
||||
self.attn_backends[i].init_cuda_graph_state(
|
||||
max_bs, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
|
||||
)
|
||||
|
||||
def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
|
||||
|
||||
Reference in New Issue
Block a user