Optimize Triton decoding kernel for dynamic workload (#4553)

This commit is contained in:
JieXin Liang
2025-03-19 12:25:38 +08:00
committed by GitHub
parent 588865f0e0
commit c0e9a36c5f
7 changed files with 277 additions and 57 deletions

View File

@@ -349,6 +349,7 @@ class FlashInferAttnBackend(AttentionBackend):
def init_forward_metadata_replay_cuda_graph(
self,
bs: int,
num_kv_heads: int,
req_pool_indices: torch.Tensor,
seq_lens: torch.Tensor,
seq_lens_sum: int,
@@ -1062,6 +1063,7 @@ class FlashInferMultiStepDraftBackend:
def call_fn(i, forward_batch):
self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
bs,
-1,
forward_batch.req_pool_indices,
forward_batch.seq_lens,
seq_lens_sum=-1,