feat: mtp support dp-attention (#6081)

Co-authored-by: austindeng <austindeng@tencent.com>
Co-authored-by: tianqilin.99 <tianqilin.99@bytedance.com>
Co-authored-by: Qiaolin Yu <liin1211@outlook.com>
Co-authored-by: ch-wan <cwan39@gatech.edu>
This commit is contained in:
u4lr451
2025-06-17 15:33:28 +08:00
committed by GitHub
parent 8a10c4c3d9
commit 10d60cd41b
22 changed files with 641 additions and 151 deletions

View File

@@ -32,11 +32,11 @@ class TboAttnBackend(AttentionBackend):
if forward_batch_child.batch_size > 0:
child.init_forward_metadata(forward_batch=forward_batch_child)
def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
    """Initialize CUDA-graph capture state on the primary backend and all children.

    Delegates to ``init_cuda_graph_state`` on ``self.primary`` first, then on each
    backend in ``self.children``, forwarding both limits unchanged.

    Args:
        max_bs: Maximum batch size the captured CUDA graph must support.
        max_num_tokens: Maximum number of tokens the captured CUDA graph must support.
    """
    self.primary.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
    for item in self.children:
        # TODO: children could potentially be given a *smaller* max_bs to optimize.
        item.init_cuda_graph_state(max_bs=max_bs, max_num_tokens=max_num_tokens)
def init_forward_metadata_capture_cuda_graph(
self,