Add kernels to optimize RoPE and the decoding stage (#143)

Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
fromck
2026-01-23 10:29:52 +08:00
committed by GitHub
parent 9e13f23661
commit 0ce5f1a3f7
5 changed files with 74 additions and 115 deletions

View File

@@ -2235,3 +2235,40 @@ def _fake_fwd_kvcache_mla(
return None
# Register the fake (meta) implementation so shape/dtype inference can run
# without executing the real kernel (e.g. under torch.compile tracing).
fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla)
##################################################
# ------------------ fast_topkv2 -------------
##################################################
@custom_op("_C::fast_topkv2", mutates_args=())
def fast_topkv2(
        score: torch.Tensor,
        lengths: torch.Tensor,
        topk: Optional[int] = 2048) -> torch.Tensor:
    """Default dispatch of the `_C::fast_topkv2` custom op.

    Delegates to the fused `xtorch_ops.fast_topkv2` kernel to pick the
    top-`topk` indices from `score`, bounded per row by `lengths`.
    NOTE(review): exact row/length semantics come from the opaque kernel —
    confirm against the xtorch_ops implementation.

    Args:
        score: score tensor handed straight to the kernel.
        lengths: per-entry valid lengths handed straight to the kernel.
        topk: number of indices to select; only 2048 is accepted for now.

    Returns:
        The index tensor produced by the kernel.
    """
    # Kernel limitation: only a fixed topk of 2048 is implemented.
    assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
    return xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
@impl("_C::fast_topkv2", "CUDA")
def fast_topkv2_cuda(
        score: torch.Tensor,
        lengths: torch.Tensor,
        topk: Optional[int] = 2048) -> torch.Tensor:
    """CUDA registration of `_C::fast_topkv2`.

    Identical to the default dispatch: forwards to the fused
    `xtorch_ops.fast_topkv2` kernel after checking the supported `topk`.

    Args:
        score: score tensor handed straight to the kernel.
        lengths: per-entry valid lengths handed straight to the kernel.
        topk: number of indices to select; only 2048 is accepted for now.

    Returns:
        The index tensor produced by the kernel.
    """
    # Kernel limitation: only a fixed topk of 2048 is implemented.
    assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
    indices = xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
    return indices
def _fake_fast_topkv2(
score: torch.Tensor,
lengths: torch.Tensor,
topk: Optional[int] = 2048) -> torch.Tensor:
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32)
return topk_indices
# Register the fake (meta) implementation so shape/dtype inference can run
# without executing the real kernel (e.g. under torch.compile tracing).
fast_topkv2.register_fake(_fake_fast_topkv2)