From af6535e7aaf5c1e9352149f0edfde37d977cd473 Mon Sep 17 00:00:00 2001 From: Alex Sun Date: Mon, 24 Mar 2025 13:58:05 +0800 Subject: [PATCH] [ROCm] Enable MTP (NextN) on AMD GPU (#4631) --- .../srt/speculative/build_eagle_tree.py | 4 ++-- python/sglang/srt/speculative/eagle_utils.py | 4 +++- sgl-kernel/csrc/speculative/eagle_utils.cu | 4 ++++ .../pytorch_extension_utils_rocm.h | 20 +++++++++++++++++++ sgl-kernel/csrc/torch_extension_rocm.cc | 12 +++++++++++ sgl-kernel/setup_rocm.py | 1 + test/srt/test_mla_deepseek_v3.py | 2 +- 7 files changed, 43 insertions(+), 4 deletions(-) create mode 100644 sgl-kernel/csrc/speculative/pytorch_extension_utils_rocm.h diff --git a/python/sglang/srt/speculative/build_eagle_tree.py b/python/sglang/srt/speculative/build_eagle_tree.py index b26d2c2e2..364ca0677 100644 --- a/python/sglang/srt/speculative/build_eagle_tree.py +++ b/python/sglang/srt/speculative/build_eagle_tree.py @@ -4,9 +4,9 @@ from typing import List import torch -from sglang.srt.utils import is_cuda_available +from sglang.srt.utils import is_cuda_available, is_hip -if is_cuda_available(): +if is_cuda_available() or is_hip(): from sgl_kernel import ( build_tree_kernel_efficient as sgl_build_tree_kernel_efficient, ) diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 3dc2a9699..0c5b9b4a5 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -14,7 +14,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient -from sglang.srt.utils import is_cuda_available +from sglang.srt.utils import is_cuda_available, is_hip if is_cuda_available(): from sgl_kernel import ( @@ -23,6 +23,8 @@ if is_cuda_available(): tree_speculative_sampling_target_only, verify_tree_greedy, ) +elif is_hip(): + from sgl_kernel import verify_tree_greedy if TYPE_CHECKING: from sglang.srt.managers.schedule_batch import ScheduleBatch diff --git a/sgl-kernel/csrc/speculative/eagle_utils.cu b/sgl-kernel/csrc/speculative/eagle_utils.cu index 968a8a264..aeb6b8421 100644 --- a/sgl-kernel/csrc/speculative/eagle_utils.cu +++ b/sgl-kernel/csrc/speculative/eagle_utils.cu @@ -17,7 +17,11 @@ #include #include +#ifndef USE_ROCM #include "pytorch_extension_utils.h" +#else +#include "pytorch_extension_utils_rocm.h" +#endif // parent_list [bs, topk * (depth - 1) + 1)] // selected_index [bs, draft_token_num - 1] diff --git a/sgl-kernel/csrc/speculative/pytorch_extension_utils_rocm.h b/sgl-kernel/csrc/speculative/pytorch_extension_utils_rocm.h new file mode 100644 index 000000000..fa5fd129f --- /dev/null +++ b/sgl-kernel/csrc/speculative/pytorch_extension_utils_rocm.h @@ -0,0 +1,20 @@ +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor") + +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") +#define CHECK_LAST_DIM_CONTIGUOUS(x) \ + TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension") + +#define CHECK_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) +#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \ + CHECK_CUDA(x); \ + CHECK_LAST_DIM_CONTIGUOUS(x) + +#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor") + +#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b) + +#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b) diff --git a/sgl-kernel/csrc/torch_extension_rocm.cc b/sgl-kernel/csrc/torch_extension_rocm.cc index d424ce6d6..6da15a384 100644 --- a/sgl-kernel/csrc/torch_extension_rocm.cc +++ b/sgl-kernel/csrc/torch_extension_rocm.cc @@ -65,6 +65,18 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) { "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! " "token_expert_indices, Tensor gating_output) -> ()"); m.impl("topk_softmax", torch::kCUDA, &topk_softmax); + + m.def( + "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, " + "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, " + "Tensor target_predict, int cuda_stream) -> ()"); + m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy); + + m.def( + "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, " + "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, " + "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()"); + m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient); } REGISTER_EXTENSION(common_ops) diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py index a9cc5edca..b147e6b53 100644 --- a/sgl-kernel/setup_rocm.py +++ b/sgl-kernel/setup_rocm.py @@ -43,6 +43,7 @@ sources = [ "csrc/moe/moe_align_kernel.cu", "csrc/moe/moe_topk_softmax_kernels.cu", "csrc/torch_extension_rocm.cc", + "csrc/speculative/eagle_utils.cu", ] cxx_flags = ["-O3"] diff --git a/test/srt/test_mla_deepseek_v3.py b/test/srt/test_mla_deepseek_v3.py index bb304ed29..42a7df59b 100644 --- a/test/srt/test_mla_deepseek_v3.py +++ b/test/srt/test_mla_deepseek_v3.py @@ -54,7 +54,7 @@ class TestDeepseekV3MTP(unittest.TestCase): cls.model = "lmsys/sglang-ci-dsv3-test" cls.base_url = DEFAULT_URL_FOR_TEST other_args = ["--trust-remote-code"] - if torch.cuda.is_available() and torch.version.cuda: + if torch.cuda.is_available() and (torch.version.cuda or torch.version.hip): other_args.extend( [ "--cuda-graph-max-bs",