Add kernels to optimize RoPE and the decoding stage (#143)
Co-authored-by: chengxiaokang <chengxiaokang@baidu.com>
This commit is contained in:
@@ -2235,3 +2235,40 @@ def _fake_fwd_kvcache_mla(
|
||||
return None
|
||||
|
||||
# Register _fake_fwd_kvcache_mla as the fake (meta) implementation of the
# fwd_kvcache_mla custom op, so tracing (fake tensors / torch.compile) can
# propagate output metadata without running the real kernel.
fwd_kvcache_mla.register_fake(_fake_fwd_kvcache_mla)
|
||||
|
||||
##################################################
# ---------------- fast_topkv2 ----------------- #
##################################################
|
||||
@custom_op("_C::fast_topkv2", mutates_args=())
def fast_topkv2(
        score: torch.Tensor,
        lengths: torch.Tensor,
        topk: Optional[int] = 2048) -> torch.Tensor:
    """Return the indices of the top-``topk`` entries of each row of ``score``.

    Thin custom-op wrapper around the fused ``xtorch_ops.fast_topkv2``
    kernel.

    Args:
        score: score tensor; top-k is taken per row.
            (assumed shape (batch, num_candidates) — TODO confirm)
        lengths: per-row valid lengths restricting the candidate range.
            (presumably 1-D of size batch — verify against callers)
        topk: number of indices to select. The kernel currently only
            supports 2048.

    Returns:
        Index tensor of the selected entries (int32, shape
        ``(score.size(0), topk)`` per the fake implementation).

    Raises:
        ValueError: if ``topk`` is not 2048.
    """
    # Explicit raise instead of ``assert``: asserts are stripped under
    # ``python -O`` and would silently pass an unsupported topk to the kernel.
    if topk != 2048:
        raise ValueError("fast_topkv2 only supports topk = 2048 by now")
    return xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||
|
||||
@impl("_C::fast_topkv2", "CUDA")
def fast_topkv2_cuda(
        score: torch.Tensor,
        lengths: torch.Tensor,
        topk: Optional[int] = 2048) -> torch.Tensor:
    """CUDA backend for the ``_C::fast_topkv2`` custom op.

    Mirrors ``fast_topkv2``: dispatches to the fused
    ``xtorch_ops.fast_topkv2`` kernel.

    Args:
        score: score tensor; top-k is taken per row.
        lengths: per-row valid lengths restricting the candidate range.
        topk: number of indices to select; only 2048 is supported.

    Returns:
        Index tensor of the selected entries (int32, shape
        ``(score.size(0), topk)`` per the fake implementation).

    Raises:
        ValueError: if ``topk`` is not 2048.
    """
    # Explicit raise instead of ``assert`` (stripped under ``python -O``),
    # keeping validation behavior consistent with fast_topkv2.
    if topk != 2048:
        raise ValueError("fast_topkv2 only supports topk = 2048 by now")
    return xtorch_ops.fast_topkv2(score=score, lengths=lengths, topk=topk)
|
||||
|
||||
def _fake_fast_topkv2(
|
||||
score: torch.Tensor,
|
||||
lengths: torch.Tensor,
|
||||
topk: Optional[int] = 2048) -> torch.Tensor:
|
||||
assert topk == 2048, "fast_topkv2 only supports topk = 2048 by now"
|
||||
topk_indices = score.new_empty((score.size(0), topk), dtype=torch.int32)
|
||||
return topk_indices
|
||||
|
||||
# Register _fake_fast_topkv2 as the fake (meta) implementation of the
# fast_topkv2 custom op, so shape/dtype propagation works under fake-tensor
# tracing without launching the kernel.
fast_topkv2.register_fake(_fake_fast_topkv2)
|
||||
Reference in New Issue
Block a user