kernel: support slightly faster merge_state_v2 cuda kernel (#5381)

This commit is contained in:
DefTruth
2025-04-15 12:28:23 +08:00
committed by GitHub
parent 11421a3f44
commit 388e15c0db
7 changed files with 638 additions and 4 deletions

View File

@@ -89,6 +89,8 @@ void lightning_attention_decode(
torch::Tensor new_kv);
void merge_state(
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
void merge_state_v2(
at::Tensor v_a, at::Tensor s_a, at::Tensor v_b, at::Tensor s_b, at::Tensor v_merged, at::Tensor s_merged);
void cutlass_mla_decode(
torch::Tensor const& out,
torch::Tensor const& q_nope_and_q_pe,