[ROCm] Enable MTP (NextN) on AMD GPU (#4631)

This commit is contained in:
Alex Sun
2025-03-24 13:58:05 +08:00
committed by GitHub
parent 93cf7fc5cd
commit af6535e7aa
7 changed files with 43 additions and 4 deletions

View File

@@ -4,9 +4,9 @@ from typing import List
import torch
from sglang.srt.utils import is_cuda_available
from sglang.srt.utils import is_cuda_available, is_hip
if is_cuda_available():
if is_cuda_available() or is_hip():
from sgl_kernel import (
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
)

View File

@@ -14,7 +14,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
from sglang.srt.utils import is_cuda_available
from sglang.srt.utils import is_cuda_available, is_hip
if is_cuda_available():
from sgl_kernel import (
@@ -23,6 +23,8 @@ if is_cuda_available():
tree_speculative_sampling_target_only,
verify_tree_greedy,
)
elif is_hip():
from sgl_kernel import verify_tree_greedy
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch