[ROCm] Enable MTP (NextN) on AMD GPU (#4631)
This commit is contained in:
@@ -4,9 +4,9 @@ from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda_available, is_hip
|
||||
|
||||
-if is_cuda_available():
+if is_cuda_available() or is_hip():
|
||||
from sgl_kernel import (
|
||||
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
||||
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda_available, is_hip
|
||||
|
||||
if is_cuda_available():
|
||||
from sgl_kernel import (
|
||||
@@ -23,6 +23,8 @@ if is_cuda_available():
|
||||
tree_speculative_sampling_target_only,
|
||||
verify_tree_greedy,
|
||||
)
|
||||
+elif is_hip():
+    from sgl_kernel import verify_tree_greedy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
|
||||
Reference in New Issue
Block a user