feat: adapt merge_state (#5337)

This commit is contained in:
Yineng Zhang
2025-04-12 21:14:04 -07:00
committed by GitHub
parent 7d3b7c87f5
commit b62e7e99b8
8 changed files with 224 additions and 3 deletions

View File

@@ -15,6 +15,7 @@ from sgl_kernel.attention import (
cutlass_mla_decode,
cutlass_mla_get_workspace_size,
lightning_attention_decode,
merge_state,
)
from sgl_kernel.elementwise import (
apply_rope_with_cos_sin_cache_inplace,

View File

@@ -1,3 +1,5 @@
from typing import Tuple
import torch
@@ -7,6 +9,17 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
)
def merge_state(
v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
s_a = s_a.to(torch.float32)
s_b = s_b.to(torch.float32)
v_merged = torch.empty_like(v_a)
s_merged = torch.empty_like(s_a)
torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
return v_merged, s_merged
def cutlass_mla_decode(
q_nope_and_q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
@@ -54,7 +67,7 @@ def cutlass_mla_decode(
(B_q, H, D_latent), device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype
)
torch.ops.sgl_kernel.cutlass_mla_decode(
torch.ops.sgl_kernel.cutlass_mla_decode.default(
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace
)
return out
@@ -63,6 +76,6 @@ def cutlass_mla_decode(
def cutlass_mla_get_workspace_size(
max_seq_len: int, num_batches: int, sm_count: int = 0
) -> int:
return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size(
return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
max_seq_len, num_batches, sm_count
)