Fuse writing KV buffer into rope kernel (part 1: sgl-kernel) (#9077)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Optional
|
||||
|
||||
import torch
|
||||
from sgl_kernel.utils import get_cuda_stream, is_hopper_arch
|
||||
@@ -237,6 +238,31 @@ if torch.version.hip is not None:
|
||||
return out
|
||||
|
||||
|
||||
@dataclass
class FusedSetKVBufferArg:
    """
    Arguments for fusing the set-kv-buffer write into the rope kernel.

    An instance is passed as ``fused_set_kv_buffer_arg`` to
    ``apply_rope_with_cos_sin_cache_inplace``. All tensor fields are
    required; only the two scale factors may be ``None``.

    value : torch.Tensor
        Value tensor, shape: ``(nnz, num_v_heads * head_size)``.
    k_buffer : torch.Tensor
        Buffer for keys, shape: ``(nnz, num_k_heads * head_size)``.
    v_buffer : torch.Tensor
        Buffer for values, shape: ``(nnz, num_v_heads * head_size)``.
    k_scale : Optional[float]
        Scale factor for keys. Must currently be ``None``:
        ``apply_rope_with_cos_sin_cache_inplace`` asserts that key scaling
        is not yet supported.
    v_scale : Optional[float]
        Scale factor for values. Must currently be ``None``:
        ``apply_rope_with_cos_sin_cache_inplace`` asserts that value scaling
        is not yet supported.
    cache_loc : torch.Tensor
        Cache location tensor, used for indexing kv cache. Must have dtype
        ``torch.int64`` (asserted by ``apply_rope_with_cos_sin_cache_inplace``).

    Note
    ----
    ``value``, ``k_buffer`` and ``v_buffer`` are reinterpreted by the caller as
    ``(shape[0], -1, head_size)`` via ``Tensor.view``, so their trailing
    dimension must be a multiple of ``head_size``.
    """

    # NOTE(review): the buffers' leading dimension is documented as ``nnz``;
    # presumably it is the number of cache slots addressed by ``cache_loc``
    # -- confirm against the kernel before relying on it.
    value: torch.Tensor
    k_buffer: torch.Tensor
    v_buffer: torch.Tensor
    k_scale: Optional[float]
    v_scale: Optional[float]
    cache_loc: torch.Tensor
|
||||
|
||||
|
||||
def apply_rope_with_cos_sin_cache_inplace(
|
||||
positions: torch.Tensor,
|
||||
query: torch.Tensor,
|
||||
@@ -244,6 +270,7 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
head_size: int,
|
||||
cos_sin_cache: torch.Tensor,
|
||||
is_neox: bool = True,
|
||||
fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
|
||||
) -> None:
|
||||
r"""
|
||||
Apply rotary embedding to keys and queries with precomputed cos/sin values.
|
||||
@@ -270,6 +297,9 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
|
||||
* If ``False``, the last dimension of the query/key tensor is interleaved, i.e.,
|
||||
we rotate the even dimensions ``([..., ::2])`` and odd dimensions ``([..., 1::2])``.
|
||||
fused_set_kv_buffer_arg : Optional[FusedSetKVBufferArg]
|
||||
If not ``None``, fuse the set-kv-buffer operation into this kernel; ``None`` (the default) disables the fusion.
|
||||
|
||||
Note
|
||||
----
|
||||
The rotary dimension is determined by the cosine cache and sine cache.
|
||||
@@ -277,13 +307,41 @@ def apply_rope_with_cos_sin_cache_inplace(
|
||||
if cos_sin_cache.dtype != torch.float32:
|
||||
raise ValueError("cos_sin_cache should be float32")
|
||||
|
||||
if (a := fused_set_kv_buffer_arg) is not None:
|
||||
assert a.k_scale is None, "k_scale is not yet supported"
|
||||
assert a.v_scale is None, "v_scale is not yet supported"
|
||||
assert a.cache_loc.dtype == torch.int64, f"{a.cache_loc.dtype=}"
|
||||
|
||||
def _view_3d(x):
|
||||
return x.view(x.shape[0], -1, head_size)
|
||||
|
||||
torch.ops.sgl_kernel.apply_rope_pos_ids_cos_sin_cache.default(
|
||||
query.view(query.shape[0], -1, head_size),
|
||||
key.view(key.shape[0], -1, head_size),
|
||||
query.view(query.shape[0], -1, head_size),
|
||||
key.view(key.shape[0], -1, head_size),
|
||||
_view_3d(query),
|
||||
_view_3d(key),
|
||||
_view_3d(query),
|
||||
_view_3d(key),
|
||||
cos_sin_cache,
|
||||
positions.long(),
|
||||
(not is_neox),
|
||||
get_cuda_stream(),
|
||||
(
|
||||
_view_3d(fused_set_kv_buffer_arg.value)
|
||||
if fused_set_kv_buffer_arg is not None
|
||||
else None
|
||||
),
|
||||
(
|
||||
_view_3d(fused_set_kv_buffer_arg.k_buffer)
|
||||
if fused_set_kv_buffer_arg is not None
|
||||
else None
|
||||
),
|
||||
(
|
||||
_view_3d(fused_set_kv_buffer_arg.v_buffer)
|
||||
if fused_set_kv_buffer_arg is not None
|
||||
else None
|
||||
),
|
||||
(
|
||||
fused_set_kv_buffer_arg.cache_loc
|
||||
if fused_set_kv_buffer_arg is not None
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user