Fuse writing KV buffer into rope kernel (part 1: sgl-kernel) (#9077)

This commit is contained in:
fzyzcjy
2025-08-12 16:46:40 +08:00
committed by GitHub
parent fcc11e5ed5
commit 9aea255522
11 changed files with 1152 additions and 194 deletions

View File

@@ -1,4 +1,5 @@
from typing import Optional
from dataclasses import dataclass
from typing import Any, Optional
import torch
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
@@ -237,6 +238,31 @@ if torch.version.hip is not None:
return out
@dataclass
class FusedSetKVBufferArg:
"""
value : Optional[torch.Tensor]
Value tensor, shape: ``(nnz, num_v_heads * head_size)``.
k_buffer : Optional[torch.Tensor]
Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``.
v_buffer : Optional[torch.Tensor]
Buffer for values, shape: ``(nnz, num_v_heads * head_size)``.
k_scale : Optional[float]
Scale factor for keys.
v_scale : Optional[float]
Scale factor for values.
cache_loc : Optional[torch.Tensor]
Cache location tensor, used for indexing kv cache.
"""
value: torch.Tensor
k_buffer: torch.Tensor
v_buffer: torch.Tensor
k_scale: Optional[float]
v_scale: Optional[float]
cache_loc: torch.Tensor
def apply_rope_with_cos_sin_cache_inplace(
positions: torch.Tensor,
query: torch.Tensor,
@@ -244,6 +270,7 @@ def apply_rope_with_cos_sin_cache_inplace(
head_size: int,
cos_sin_cache: torch.Tensor,
is_neox: bool = True,
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
) -> None:
r"""
Apply rotary embedding to keys and queries with precomputed cos/sin values.
@@ -270,6 +297,9 @@ def apply_rope_with_cos_sin_cache_inplace(
* If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
fused_set_kv_buffer_arg : FusedSetKVBufferArg
Fuse the set-kv-buffer operation into this kernel
Note
----
The rotary dimension is determined by the cosine cache and sine cache.
@@ -277,13 +307,41 @@ def apply_rope_with_cos_sin_cache_inplace(
if cos_sin_cache.dtype != torch.float32:
raise ValueError("cos_sin_cache should be float32")
if (a := fused_set_kv_buffer_arg) is not None:
assert a.k_scale is None, "k_scale is not yet supported"
assert a.v_scale is None, "v_scale is not yet supported"
assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"
def _view_3d(x):
return x.view(x.shape[0], -1, head_size)
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
query.view(query.shape[0], -1, head_size),
key.view(key.shape[0], -1, head_size),
query.view(query.shape[0], -1, head_size),
key.view(key.shape[0], -1, head_size),
_view_3d(query),
_view_3d(key),
_view_3d(query),
_view_3d(key),
cos_sin_cache,
positions.long(),
(not is_neox),
get_cuda_stream(),
(
_view_3d(fused_set_kv_buffer_arg.value)
if fused_set_kv_buffer_arg is not None
else None
),
(
_view_3d(fused_set_kv_buffer_arg.k_buffer)
if fused_set_kv_buffer_arg is not None
else None
),
(
_view_3d(fused_set_kv_buffer_arg.v_buffer)
if fused_set_kv_buffer_arg is not None
else None
),
(
fused_set_kv_buffer_arg.cache_loc
if fused_set_kv_buffer_arg is not None
else None
),
)