kernel: support slightly faster merge_state_v2 cuda kernel (#5381)
This commit is contained in:
@@ -16,6 +16,7 @@ from sgl_kernel.attention import (
|
||||
cutlass_mla_get_workspace_size,
|
||||
lightning_attention_decode,
|
||||
merge_state,
|
||||
merge_state_v2,
|
||||
)
|
||||
from sgl_kernel.elementwise import (
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
@@ -10,16 +10,47 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
||||
|
||||
|
||||
def merge_state(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge two partial attention states via the sgl_kernel CUDA op.

    Args:
        v_a: first partial attention output (presumably value tensor — confirm
            layout against the kernel's docs).
        s_a: first partial score tensor; cast to float32 before the kernel call.
        v_b: second partial attention output, same shape as ``v_a``.
        s_b: second partial score tensor; cast to float32 before the kernel call.
        v_merged: optional preallocated output for the merged values. When
            ``None``, a fresh tensor shaped like ``v_a`` is allocated.
        s_merged: optional preallocated output for the merged scores. When
            ``None``, a fresh tensor shaped like ``s_a`` is allocated.

    Returns:
        Tuple ``(v_merged, s_merged)`` filled in by the kernel.
    """
    # The kernel consumes float32 scores; cast defensively.
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    # Avoid creating new tensors if they are already provided.
    # (Bug fix: outputs were previously allocated unconditionally, which made
    # these None-checks dead code and ignored caller-supplied buffers.)
    if v_merged is None:
        v_merged = torch.empty_like(v_a)
    if s_merged is None:
        s_merged = torch.empty_like(s_a)
    torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
    return v_merged, s_merged
|
||||
|
||||
|
||||
def merge_state_v2(
    v_a: torch.Tensor,
    s_a: torch.Tensor,
    v_b: torch.Tensor,
    s_b: torch.Tensor,
    v_merged: Optional[torch.Tensor] = None,
    s_merged: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Merge two partial attention states using the faster v2 CUDA kernel.

    Behaves like :func:`merge_state` but dispatches to
    ``torch.ops.sgl_kernel.merge_state_v2``. Callers may pass preallocated
    ``v_merged`` / ``s_merged`` buffers to skip output allocation.

    Returns:
        Tuple ``(v_merged, s_merged)`` filled in by the kernel.
    """
    # Scores must be float32 for the kernel.
    s_a = s_a.to(torch.float32)
    s_b = s_b.to(torch.float32)
    # TODO(DefTruth): Currently, the custom merge_attn_states kernel
    # does not support the FP8 data type and non-CUDA devices.
    # It may be necessary to fall back to using the Triton kernel.

    # Reuse caller-provided output buffers; allocate only when absent.
    v_merged = torch.empty_like(v_a) if v_merged is None else v_merged
    s_merged = torch.empty_like(s_a) if s_merged is None else s_merged
    torch.ops.sgl_kernel.merge_state_v2.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
    return v_merged, s_merged
|
||||
|
||||
|
||||
def cutlass_mla_decode(
|
||||
q_nope_and_q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
|
||||
Reference in New Issue
Block a user