feat: adapt merge_state (#5337)
This commit is contained in:
@@ -15,6 +15,7 @@ from sgl_kernel.attention import (
|
||||
cutlass_mla_decode,
|
||||
cutlass_mla_get_workspace_size,
|
||||
lightning_attention_decode,
|
||||
merge_state,
|
||||
)
|
||||
from sgl_kernel.elementwise import (
|
||||
apply_rope_with_cos_sin_cache_inplace,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
@@ -7,6 +9,17 @@ def lightning_attention_decode(q, k, v, past_kv, slope, output, new_kv):
|
||||
)
|
||||
|
||||
|
||||
def merge_state(
|
||||
v_a: torch.Tensor, s_a: torch.Tensor, v_b: torch.Tensor, s_b: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
s_a = s_a.to(torch.float32)
|
||||
s_b = s_b.to(torch.float32)
|
||||
v_merged = torch.empty_like(v_a)
|
||||
s_merged = torch.empty_like(s_a)
|
||||
torch.ops.sgl_kernel.merge_state.default(v_a, s_a, v_b, s_b, v_merged, s_merged)
|
||||
return v_merged, s_merged
|
||||
|
||||
|
||||
def cutlass_mla_decode(
|
||||
q_nope_and_q_pe: torch.Tensor,
|
||||
kv_c_and_k_pe_cache: torch.Tensor,
|
||||
@@ -54,7 +67,7 @@ def cutlass_mla_decode(
|
||||
(B_q, H, D_latent), device=q_nope_and_q_pe.device, dtype=q_nope_and_q_pe.dtype
|
||||
)
|
||||
|
||||
torch.ops.sgl_kernel.cutlass_mla_decode(
|
||||
torch.ops.sgl_kernel.cutlass_mla_decode.default(
|
||||
out, q_nope_and_q_pe, kv_c_and_k_pe_cache, seq_lens, page_table, workspace
|
||||
)
|
||||
return out
|
||||
@@ -63,6 +76,6 @@ def cutlass_mla_decode(
|
||||
def cutlass_mla_get_workspace_size(
|
||||
max_seq_len: int, num_batches: int, sm_count: int = 0
|
||||
) -> int:
|
||||
return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size(
|
||||
return torch.ops.sgl_kernel.cutlass_mla_get_workspace_size.default(
|
||||
max_seq_len, num_batches, sm_count
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user