[Feature] Support custom set kv buffer kernel (#8884)

This commit is contained in:
DarkSharpness
2025-08-12 16:56:51 -07:00
committed by GitHub
parent 0edda32001
commit 86a0be65d8
6 changed files with 178 additions and 0 deletions

View File

@@ -67,6 +67,7 @@ from sgl_kernel.marlin import (
awq_marlin_repack,
gptq_marlin_repack,
)
from sgl_kernel.memory import set_kv_buffer_kernel
from sgl_kernel.moe import (
apply_shuffle_mul_sum,
cutlass_fp4_group_mm,

View File

@@ -0,0 +1,18 @@
import torch
def set_kv_buffer_kernel(
k_cache: torch.Tensor,
v_cache: torch.Tensor,
loc: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
fallback: bool = False,
):
try:
if fallback:
raise RuntimeError("Fallback to torch implementation")
torch.ops.sgl_kernel.store_kv_cache(k_cache, v_cache, loc, k, v)
except RuntimeError: # ok, fallback to torch implementation
k_cache[loc] = k
v_cache[loc] = v