Fix sampling for speculative decoding & simplify kernels (#7207)
This commit is contained in:
11
sgl-kernel/python/sgl_kernel/top_k.py
Normal file
11
sgl-kernel/python/sgl_kernel/top_k.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import torch
|
||||
|
||||
|
||||
def fast_topk(values, topk, dim):
    """Return the ``topk`` largest values and their indices along ``dim``.

    Returns a ``(values, indices)`` pair, matching the shape convention of
    ``torch.topk`` (the k==1 fast path keeps the reduced dim via keepdim).
    """
    if topk == 1:
        # For k == 1 a single max reduction is cheaper than a full topk;
        # keepdim=True keeps the output shape consistent with torch.topk.
        return torch.max(values, dim=dim, keepdim=True)
    # General case: fall back to torch.topk.
    # TODO: implement faster cuda kernels for large vocab sizes
    return torch.topk(values, topk, dim=dim)
|
||||
Reference in New Issue
Block a user