kernel: support slightly faster merge_state_v2 cuda kernel (#5381)

This commit is contained in:
DefTruth
2025-04-15 12:28:23 +08:00
committed by GitHub
parent 11421a3f44
commit 388e15c0db
7 changed files with 638 additions and 4 deletions

View File

@@ -16,6 +16,7 @@ from sgl_kernel.attention import (
cutlass_mla_get_workspace_size,
lightning_attention_decode,
merge_state,
merge_state_v2,
)
from sgl_kernel.elementwise import (
apply_rope_with_cos_sin_cache_inplace,

View File

@@ -1,4 +1,4 @@
from typing import Tuple
from typing import Optional, Tuple
import torch
@@ -10,16 +10,47 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
def merge_state(
v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
v_a: torch.Tensor,
s_a: torch.Tensor,
v_b: torch.Tensor,
s_b: torch.Tensor,
v_merged: Optional[torch.Tensor] = None,
s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
s_a = s_a.to(torch.float32)
s_b = s_b.to(torch.float32)
v_merged = torch.empty_like(v_a)
s_merged = torch.empty_like(s_a)
# Avoid creating new tensors if they are already provided
if v_merged is None:
v_merged = torch.empty_like(v_a)
if s_merged is None:
s_merged = torch.empty_like(s_a)
torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
return v_merged, s_merged
def merge_state_v2(
v_a: torch.Tensor,
s_a: torch.Tensor,
v_b: torch.Tensor,
s_b: torch.Tensor,
v_merged: Optional[torch.Tensor] = None,
s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
s_a = s_a.to(torch.float32)
s_b = s_b.to(torch.float32)
# TODO(DefTruth): Currently, the custom merge_attn_states kernel
# does not support the FP8 data type and non - CUDA devices.
# It may be necessary to fall back to using the Triton kernel.
# Avoid creating new tensors if they are already provided
if v_merged is None:
v_merged = torch.empty_like(v_a)
if s_merged is None:
s_merged = torch.empty_like(s_a)
torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
return v_merged, s_merged
def cutlass_mla_decode(
q_nope_and_q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,