[ROCm] Enable MTP (NextN) on AMD GPU (#4631)
@@ -4,9 +4,9 @@ from typing import List
 import torch
 
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda_available, is_hip
 
-if is_cuda_available():
+if is_cuda_available() or is_hip():
     from sgl_kernel import (
         build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
     )
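For context, the is_hip() gate added above typically amounts to checking whether the running PyTorch build reports a HIP/ROCm version. A minimal illustrative sketch of both helpers (not sglang's exact implementation in sglang.srt.utils):

import torch

def is_cuda_available() -> bool:
    # True when a CUDA (NVIDIA) build of PyTorch sees a usable device.
    return torch.cuda.is_available() and torch.version.cuda is not None

def is_hip() -> bool:
    # True when PyTorch was built against ROCm/HIP: torch.version.hip is a
    # version string on ROCm builds and None otherwise.
    return torch.version.hip is not None

if is_cuda_available() or is_hip():
    # Either backend can provide the sgl_kernel speculative-decoding ops.
    pass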
@@ -14,7 +14,7 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode
 from sglang.srt.speculative.build_eagle_tree import build_tree_kernel_efficient
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda_available, is_hip
 
 if is_cuda_available():
     from sgl_kernel import (
@@ -23,6 +23,8 @@ if is_cuda_available():
         tree_speculative_sampling_target_only,
         verify_tree_greedy,
     )
+elif is_hip():
+    from sgl_kernel import verify_tree_greedy
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
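On ROCm only verify_tree_greedy is imported from sgl_kernel here; build_tree_kernel_efficient comes from the Python-side build_eagle_tree module. As a rough conceptual sketch of what greedy verification does, here is the degenerate single-chain case in plain Python, ignoring the tree layout and the kernel's actual in-place tensor contract:

import torch

def greedy_verify_chain(draft_tokens: torch.Tensor, target_argmax: torch.Tensor):
    """draft_tokens[i] is the i-th drafted token; target_argmax[i] is the target
    model's greedy choice at the same position. Accept the longest matching
    prefix, then take the target's token at the first mismatch."""
    accepted = []
    for draft, target in zip(draft_tokens.tolist(), target_argmax.tolist()):
        if draft == target:
            accepted.append(draft)   # draft agreed with the target model
        else:
            accepted.append(target)  # correct the first mismatch, then stop
            break
    return accepted

# Example: three of four draft tokens survive; the fourth is replaced.
print(greedy_verify_chain(torch.tensor([5, 9, 2, 7]), torch.tensor([5, 9, 2, 4])))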
@@ -17,7 +17,11 @@
 #include <ATen/ATen.h>
 #include <ATen/cuda/CUDAContext.h>
 
+#ifndef USE_ROCM
 #include "pytorch_extension_utils.h"
+#else
+#include "pytorch_extension_utils_rocm.h"
+#endif
 
 // parent_list [bs, topk * (depth - 1) + 1)]
 // selected_index [bs, draft_token_num - 1]
sgl-kernel/csrc/speculative/pytorch_extension_utils_rocm.h (new file, 20 lines)
@@ -0,0 +1,20 @@
+#include <torch/library.h>
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
+
+#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_LAST_DIM_CONTIGUOUS(x) \
+  TORCH_CHECK(x.strides()[x.strides().size() - 1] == 1, #x "must be contiguous at last dimension")
+
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
+  CHECK_CONTIGUOUS(x)
+#define CHECK_LAST_DIM_CONTIGUOUS_INPUT(x) \
+  CHECK_CUDA(x);                           \
+  CHECK_LAST_DIM_CONTIGUOUS(x)
+
+#define CHECK_DIM(d, x) TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
+
+#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
+
+#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
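These macros give the ROCm build the same tensor validation the CUDA build gets from pytorch_extension_utils.h: device, contiguity, and rank checks before a kernel launch. An illustrative Python analog of CHECK_LAST_DIM_CONTIGUOUS_INPUT and CHECK_DIM (not part of the change) looks like this:

import torch

def check_last_dim_contiguous_input(x: torch.Tensor, name: str) -> None:
    # CHECK_CUDA: the tensor must live on the GPU (ROCm builds of PyTorch
    # still report the device type as "cuda").
    assert x.is_cuda, f"{name} must be a CUDA tensor"
    # CHECK_LAST_DIM_CONTIGUOUS: the stride of the last dimension must be 1.
    assert x.stride(-1) == 1, f"{name} must be contiguous at last dimension"

def check_dim(d: int, x: torch.Tensor, name: str) -> None:
    # CHECK_DIM: the tensor must have exactly d dimensions.
    assert x.dim() == d, f"{name} must be a {d}D tensor"

t = torch.zeros(4, 8)
check_dim(2, t, "t")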
@@ -65,6 +65,18 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
       "topk_softmax(Tensor! topk_weights, Tensor! topk_indices, Tensor! "
       "token_expert_indices, Tensor gating_output) -> ()");
   m.impl("topk_softmax", torch::kCUDA, &topk_softmax);
+
+  m.def(
+      "verify_tree_greedy(Tensor! predicts, Tensor! accept_index, Tensor! accept_token_num, "
+      "Tensor candidates, Tensor retrive_index, Tensor retrive_next_token, Tensor retrive_next_sibling, "
+      "Tensor target_predict, int cuda_stream) -> ()");
+  m.impl("verify_tree_greedy", torch::kCUDA, &verify_tree_greedy);
+
+  m.def(
+      "build_tree_kernel_efficient(Tensor parent_list, Tensor selected_index, Tensor verified_seq_len, "
+      "Tensor! tree_mask, Tensor! positions, Tensor! retrive_index, Tensor! retrive_next_token, "
+      "Tensor! retrive_next_sibling, int topk, int depth, int draft_token_num) -> ()");
+  m.impl("build_tree_kernel_efficient", torch::kCUDA, &build_tree_kernel_efficient);
 }
 
 REGISTER_EXTENSION(common_ops)
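The m.def/m.impl pairs above are what make the speculative kernels reachable from Python through the torch.ops.sgl_kernel namespace once the ROCm extension is loaded. A small self-contained illustration of the same define/impl registration pathway, using a toy op and torch.library (PyTorch >= 2.1) rather than sgl_kernel:

import torch

# Register a schema, then bind an implementation for a device type; this mirrors
# m.def(...) followed by m.impl(..., torch::kCUDA, &fn) in torch_extension_rocm.cc.
torch.library.define("toy::add_one", "(Tensor x) -> Tensor")

@torch.library.impl("toy::add_one", "cpu")
def add_one_cpu(x: torch.Tensor) -> torch.Tensor:
    # CPU "kernel" for the registered schema; a real extension binds a CUDA/HIP
    # function here instead.
    return x + 1

# The op is now callable through torch.ops, just as the registered kernels become
# torch.ops.sgl_kernel.verify_tree_greedy / build_tree_kernel_efficient.
print(torch.ops.toy.add_one(torch.tensor([1, 2, 3])))  # tensor([2, 3, 4])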
@@ -43,6 +43,7 @@ sources = [
     "csrc/moe/moe_align_kernel.cu",
     "csrc/moe/moe_topk_softmax_kernels.cu",
     "csrc/torch_extension_rocm.cc",
+    "csrc/speculative/eagle_utils.cu",
 ]
 
 cxx_flags = ["-O3"]
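Adding csrc/speculative/eagle_utils.cu to the ROCm source list is what compiles the speculative kernels into the AMD build. For orientation, a stripped-down, hypothetical analog of such a build script (names and flags are illustrative, not the project's actual configuration; on a ROCm build of PyTorch, CUDAExtension hipifies the listed .cu sources automatically):

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension

sources = [
    "csrc/moe/moe_align_kernel.cu",
    "csrc/moe/moe_topk_softmax_kernels.cu",
    "csrc/torch_extension_rocm.cc",
    "csrc/speculative/eagle_utils.cu",  # newly added: EAGLE/NextN speculative kernels
]

setup(
    name="sgl-kernel-example",
    ext_modules=[
        CUDAExtension(
            name="common_ops",
            sources=sources,
            extra_compile_args={"cxx": ["-O3"]},
        )
    ],
    cmdclass={"build_ext": BuildExtension},
)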
@@ -54,7 +54,7 @@ class TestDeepseekV3MTP(unittest.TestCase):
         cls.model = "lmsys/sglang-ci-dsv3-test"
         cls.base_url = DEFAULT_URL_FOR_TEST
         other_args = ["--trust-remote-code"]
-        if torch.cuda.is_available() and torch.version.cuda:
+        if torch.cuda.is_available() and (torch.version.cuda or torch.version.hip):
             other_args.extend(
                 [
                     "--cuda-graph-max-bs",