[Feature] Support custom set kv buffer kernel (#8884)
This commit is contained in:
@@ -67,6 +67,7 @@ from sgl_kernel.marlin import (
|
||||
awq_marlin_repack,
|
||||
gptq_marlin_repack,
|
||||
)
|
||||
from sgl_kernel.memory import set_kv_buffer_kernel
|
||||
from sgl_kernel.moe import (
|
||||
apply_shuffle_mul_sum,
|
||||
cutlass_fp4_group_mm,
|
||||
|
||||
18
sgl-kernel/python/sgl_kernel/memory.py
Normal file
18
sgl-kernel/python/sgl_kernel/memory.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import torch
|
||||
|
||||
|
||||
def set_kv_buffer_kernel(
    k_cache: torch.Tensor,
    v_cache: torch.Tensor,
    loc: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    fallback: bool = False,
) -> None:
    """Write ``k`` and ``v`` into ``k_cache`` and ``v_cache`` at ``loc``.

    The custom ``torch.ops.sgl_kernel.store_kv_cache`` op is tried first.
    If it raises ``RuntimeError``, or if ``fallback`` is true, the write is
    done with plain torch indexed assignment instead
    (``k_cache[loc] = k`` and ``v_cache[loc] = v``).

    Args:
        k_cache: Destination tensor for keys; modified in place.
        v_cache: Destination tensor for values; modified in place.
        loc: Index tensor selecting the destination entries along the first
            dimension of the caches (as used by ``cache[loc] = ...``).
        k: Keys to store.  Presumably one entry per element of ``loc`` --
            TODO confirm the expected shape against the kernel.
        v: Values to store; same layout assumption as ``k``.
        fallback: When true, skip the custom op entirely and use the torch
            implementation.

    Returns:
        None.  Both caches are updated in place.

    Raises:
        Any non-``RuntimeError`` exception from the custom op (for example
        an ``AttributeError`` if the op is not registered) propagates
        unchanged, as do errors from the torch indexed assignment.
    """
    if not fallback:
        try:
            torch.ops.sgl_kernel.store_kv_cache(k_cache, v_cache, loc, k, v)
        except RuntimeError:
            # Deliberate best-effort: a RuntimeError from the custom op is
            # treated as "cannot handle this call" and we fall through to
            # the torch path below.
            # NOTE(review): this also hides genuine runtime failures raised
            # by the kernel -- confirm that is intended.
            pass
        else:
            # The custom op completed; nothing left to do.
            return

    # Torch implementation: indexed assignment into both caches.
    k_cache[loc] = k
    v_cache[loc] = v
|
||||
Reference in New Issue
Block a user