ROCm: sgl-kernel enablement starting with sgl_moe_align_block (#3287)
This commit is contained in:
@@ -28,6 +28,9 @@ RUN git clone ${SGL_REPO} \
|
|||||||
echo "Using ${SGL_BRANCH} branch."; \
|
echo "Using ${SGL_BRANCH} branch."; \
|
||||||
git checkout ${SGL_BRANCH}; \
|
git checkout ${SGL_BRANCH}; \
|
||||||
fi \
|
fi \
|
||||||
|
&& cd sgl-kernel \
|
||||||
|
&& python setup_rocm.py install \
|
||||||
|
&& cd .. \
|
||||||
&& if [ "$BUILD_TYPE" = "srt" ]; then \
|
&& if [ "$BUILD_TYPE" = "srt" ]; then \
|
||||||
python -m pip --no-cache-dir install -e "python[srt_hip]"; \
|
python -m pip --no-cache-dir install -e "python[srt_hip]"; \
|
||||||
else \
|
else \
|
||||||
|
|||||||
@@ -32,7 +32,9 @@ git clone -b v0.4.2.post1 https://github.com/sgl-project/sglang.git
|
|||||||
cd sglang
|
cd sglang
|
||||||
|
|
||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install sgl-kernel --force-reinstall --no-deps
|
cd sgl-kernel
|
||||||
|
python setup_rocm.py install
|
||||||
|
cd ..
|
||||||
pip install -e "python[all_hip]"
|
pip install -e "python[all_hip]"
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ srt = [
|
|||||||
|
|
||||||
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
# HIP (Heterogeneous-computing Interface for Portability) for AMD
|
||||||
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
|
# => base docker rocm/vllm-dev:20241022, not from public vllm whl
|
||||||
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"]
|
srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11", "sgl-kernel>=0.0.3.post1"]
|
||||||
# xpu is not enabled in public vllm and torch whl,
|
# xpu is not enabled in public vllm and torch whl,
|
||||||
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
|
# need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html to install vllm
|
||||||
srt_xpu = ["sglang[runtime_common]", "outlines>=0.0.44,<0.1.0"]
|
srt_xpu = ["sglang[runtime_common]", "outlines>=0.0.44,<0.1.0"]
|
||||||
|
|||||||
@@ -15,18 +15,10 @@ from vllm import _custom_ops as ops
|
|||||||
|
|
||||||
from sglang.srt.layers.moe.topk import select_experts
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
|
||||||
from sglang.srt.utils import (
|
from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip
|
||||||
direct_register_custom_op,
|
|
||||||
get_device_name,
|
|
||||||
is_cuda_available,
|
|
||||||
is_hip,
|
|
||||||
)
|
|
||||||
|
|
||||||
is_cuda = is_cuda_available()
|
|
||||||
is_hip_flag = is_hip()
|
is_hip_flag = is_hip()
|
||||||
if is_cuda:
|
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
||||||
from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
|
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
|
||||||
@@ -415,7 +407,7 @@ def moe_align_block_size(
|
|||||||
)
|
)
|
||||||
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
|
||||||
if num_experts >= 224:
|
if num_experts >= 224:
|
||||||
if enable_moe_align_block_size_triton or is_hip_flag:
|
if enable_moe_align_block_size_triton:
|
||||||
moe_align_block_size_triton(
|
moe_align_block_size_triton(
|
||||||
topk_ids,
|
topk_ids,
|
||||||
num_experts,
|
num_experts,
|
||||||
|
|||||||
92
sgl-kernel/setup_rocm.py
Normal file
92
sgl-kernel/setup_rocm.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
# Copyright 2025 SGLang Team. All Rights Reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
import multiprocessing
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from setuptools import find_packages, setup
|
||||||
|
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||||
|
|
||||||
|
# Absolute directory containing this setup script; other paths in this file
# are built relative to it.
root = Path(__file__).parent.resolve()

# When a wheel is being built and the caller did not pass an explicit
# --plat-name, append one so the wheel is tagged manylinux2014_x86_64.
if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv:
    sys.argv.extend(["--plat-name", "manylinux2014_x86_64"])
|
||||||
|
|
||||||
|
|
||||||
|
def _get_version():
|
||||||
|
with open(root / "pyproject.toml") as f:
|
||||||
|
for line in f:
|
||||||
|
if line.startswith("version"):
|
||||||
|
return line.split("=")[1].strip().strip('"')
|
||||||
|
|
||||||
|
|
||||||
|
# Build configuration for the ROCm extension; every value below is consumed
# by the setup() call at the bottom of this file.

# Namespace token handed to the device compiler as -DOPERATOR_NAMESPACE.
operator_namespace = "sgl_kernels"
# Header search paths, anchored at this script's directory (`root`).
include_dirs = [
    root / "src" / "sgl-kernel" / "include",
    root / "src" / "sgl-kernel" / "csrc",
]

# Source files compiled into the extension, given relative to this script.
sources = [
    "src/sgl-kernel/torch_extension_rocm.cc",
    "src/sgl-kernel/csrc/moe_align_kernel.cu",
]

# Flags for the "cxx" entry of extra_compile_args.
cxx_flags = ["-O3"]
# Libraries the extension links against.
libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"]
# The rpath is expressed relative to the built module ($ORIGIN).
# NOTE(review): presumably so the module finds torch's shared libraries in a
# sibling `torch/lib` directory at import time -- confirm against the install layout.
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"]

# Flags for the "nvcc" entry of extra_compile_args.
hipcc_flags = [
    "-DNDEBUG",
    f"-DOPERATOR_NAMESPACE={operator_namespace}",
    "-O3",
    "-Xcompiler",
    "-fPIC",
    "-std=c++17",
    "-D__HIP_PLATFORM_AMD__=1",
    # The only GPU architecture targeted is gfx942.
    # NOTE(review): hard-coded; other AMD GPUs would need this changed.
    "--amdgpu-target=gfx942",
    "-DENABLE_BF16",
    "-DENABLE_FP8",
]
|
||||||
|
|
||||||
|
# Package the "sgl-kernel" distribution with one compiled extension module,
# sgl_kernel.ops._kernels, built from `sources` with the flags defined above.
setup(
    name="sgl-kernel",
    # Version string read from pyproject.toml.
    version=_get_version(),
    packages=find_packages(),
    package_dir={"": "src"},
    ext_modules=[
        CUDAExtension(
            name="sgl_kernel.ops._kernels",
            sources=sources,
            include_dirs=include_dirs,
            extra_compile_args={
                # NOTE(review): the hipcc flags are passed under the "nvcc"
                # key; presumably torch's CUDAExtension reuses that key for
                # the HIP compiler on ROCm builds -- confirm.
                "nvcc": hipcc_flags,
                "cxx": cxx_flags,
            },
            libraries=libraries,
            extra_link_args=extra_link_args,
            py_limited_api=True,
        ),
    ],
    cmdclass={
        # Build with ninja; max_jobs is set to the machine's CPU count.
        # NOTE(review): verify that BuildExtension actually honors a
        # `max_jobs` option rather than only the MAX_JOBS environment variable.
        "build_ext": BuildExtension.with_options(
            use_ninja=True, max_jobs=multiprocessing.cpu_count()
        )
    },
    # Tag wheels for the limited API with cp39, matching py_limited_api=True
    # on the extension above.
    options={"bdist_wheel": {"py_limited_api": "cp39"}},
    install_requires=["torch"],
)
|
||||||
29
sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc
Normal file
29
sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
/* Copyright 2025 SGLang Team. All Rights Reserved.
|
||||||
|
|
||||||
|
Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
you may not use this file except in compliance with the License.
|
||||||
|
You may obtain a copy of the License at
|
||||||
|
|
||||||
|
http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
|
||||||
|
Unless required by applicable law or agreed to in writing, software
|
||||||
|
distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
See the License for the specific language governing permissions and
|
||||||
|
limitations under the License.
|
||||||
|
==============================================================================*/
|
||||||
|
|
||||||
|
#include <ATen/core/dispatch/Dispatcher.h>
|
||||||
|
#include <torch/library.h>
|
||||||
|
|
||||||
|
#include "sgl_kernels_ops.h"
|
||||||
|
|
||||||
|
// Operator registration for the ROCm build: declares the schema of
// sgl_kernels::moe_align_block_size and binds its implementation.
TORCH_LIBRARY_EXPAND(sgl_kernels, m) {
  // moe_align_block_size
  // Schema: the op returns nothing; the arguments annotated `Tensor!` are
  // declared mutable, the remaining arguments are inputs.
  m.def(
      "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! "
      "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()");
  // Bound to the CUDA dispatch key.
  // NOTE(review): presumably ROCm builds of PyTorch route HIP devices through
  // this key -- confirm.
  m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size);
}

// NOTE(review): REGISTER_EXTENSION is not defined in this file (presumably it
// comes from sgl_kernels_ops.h); the `_kernels` argument matches the last
// component of the extension name used in setup_rocm.py
// (sgl_kernel.ops._kernels).
REGISTER_EXTENSION(_kernels)
|
||||||
Reference in New Issue
Block a user