[ROCm] Enable MTP (NextN) on AMD GPU (#4631)
This commit is contained in:
@@ -4,9 +4,9 @@ from typing import List
|
||||
|
||||
import torch
|
||||
|
||||
from sglang.srt.utils import is_cuda_available
|
||||
from sglang.srt.utils import is_cuda_available, is_hip
|
||||
|
||||
if is_cuda_available():
|
||||
if is_cuda_available() or is_hip():
|
||||
from sgl_kernel import (
|
||||
build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
||||
from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
|
||||
from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
|
||||
from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
|
||||
from sglang.srt.utils import is_cuda_available
|
||||
from sglang.srt.utils import is_cuda_available, is_hip
|
||||
|
||||
if is_cuda_available():
|
||||
from sgl_kernel import (
|
||||
@@ -23,6 +23,8 @@ if is_cuda_available():
|
||||
tree_speculative_sampling_target_only,
|
||||
verify_tree_greedy,
|
||||
)
|
||||
elif is_hip():
|
||||
from sgl_kernel import verify_tree_greedy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import ScheduleBatch
|
||||
|
||||
@@ -17,7 +17,11 @@
|
||||
#include <ATen/ATen.h>
|
||||
#include <ATen/cuda/CUDAContext.h>
|
||||
|
||||
#ifndef USE_ROCM
|
||||
#include "pytorch_extension_utils.h"
|
||||
#else
|
||||
#include "pytorch_extension_utils_rocm.h"
|
||||
#endif
|
||||
|
||||
// parent_list [bs, topk * (depth - 1) + 1)]
|
||||
// selected_index [bs, draft_token_num - 1]
|
||||
|
||||
20
sgl-kernel/csrc/speculative/pytorch_extension_utils_rocm.h
Normal file
20
sgl-kernel/csrc/speculative/pytorch_extension_utils_rocm.h
Normal file
@@ -0,0 +1,20 @@
|
||||
#include <torch/library.h>
|
||||
|
||||
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
|
||||
|
||||
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS(x) \
|
||||
TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
|
||||
|
||||
#define CHECK_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_CONTIGUOUS(x)
|
||||
#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
|
||||
CHECK_CUDA(x); \
|
||||
CHECK_LAST_DIM_CONTIGUOUS(x)
|
||||
|
||||
#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
|
||||
|
||||
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
|
||||
#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
|
||||
@@ -65,6 +65,18 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
|
||||
"topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
|
||||
"token_expert_indices, Tensor gating_output) -> ()");
|
||||
m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
|
||||
|
||||
m.def(
|
||||
"verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
|
||||
"Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
|
||||
"Tensor target_predict, int cuda_stream) -> ()");
|
||||
m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
|
||||
|
||||
m.def(
|
||||
"build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
|
||||
"Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
|
||||
"Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
|
||||
m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
|
||||
}
|
||||
|
||||
REGISTER_EXTENSION(common_ops)
|
||||
|
||||
@@ -43,6 +43,7 @@ sources = [
|
||||
"csrc/moe/moe_align_kernel.cu",
|
||||
"csrc/moe/moe_topk_softmax_kernels.cu",
|
||||
"csrc/torch_extension_rocm.cc",
|
||||
"csrc/speculative/eagle_utils.cu",
|
||||
]
|
||||
|
||||
cxx_flags = ["-O3"]
|
||||
|
||||
@@ -54,7 +54,7 @@ class TestDeepseekV3MTP(unittest.TestCase):
|
||||
cls.model = "lmsys/sglang-ci-dsv3-test"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
other_args = ["--trust-remote-code"]
|
||||
if torch.cuda.is_available() and torch.version.cuda:
|
||||
if torch.cuda.is_available() and (torch.version.cuda or torch.version.hip):
|
||||
other_args.extend(
|
||||
[
|
||||
"--cuda-graph-max-bs",
|
||||
|
||||
Reference in New Issue
Block a user