Fix AMD speculative decoding (#7252)
This commit is contained in:
@@ -27,14 +27,14 @@ from sglang.srt.utils import is_cuda, is_hip, next_power_of_2
 
 if is_cuda():
     from sgl_kernel import (
+        fast_topk,
         top_k_renorm_prob,
         top_p_renorm_prob,
         tree_speculative_sampling_target_only,
         verify_tree_greedy,
     )
-    from sgl_kernel.top_k import fast_topk
 elif is_hip():
-    from sgl_kernel import verify_tree_greedy
+    from sgl_kernel import fast_topk, verify_tree_greedy
 
 
 logger = logging.getLogger(__name__)
Reference in New Issue
Block a user