[doc] add note for get_num_kv_splits in triton_backend (#6444)

This commit is contained in:
JieXin Liang
2025-05-20 12:40:21 +08:00
committed by GitHub
parent 32cc66efa5
commit 69af3ec35f

View File

@@ -155,6 +155,9 @@ class TritonAttnBackend(AttentionBackend):
seq_lens: torch.Tensor,
):
num_token, num_seq = num_kv_splits.shape[0], seq_lens.shape[0]
# NOTE(alcanderian): With speculative decoding,
# num_kv_splits.shape[0] will be topk * real_num_token,
# where real_num_token equals num_seq in the decoding phase.
num_group = num_token // num_seq
assert (