From 2c1a695ff111cadc200cf97d4c2cbfe95ebecb70 Mon Sep 17 00:00:00 2001 From: HAI Date: Tue, 4 Feb 2025 05:44:44 -0800 Subject: [PATCH] ROCm: sgl-kernel enablement starting with sgl_moe_align_block (#3287) --- docker/Dockerfile.rocm | 3 + docs/start/install.md | 4 +- python/pyproject.toml | 2 +- .../layers/moe/fused_moe_triton/fused_moe.py | 14 +-- sgl-kernel/setup_rocm.py | 92 +++++++++++++++++++ .../src/sgl-kernel/torch_extension_rocm.cc | 29 ++++++ 6 files changed, 131 insertions(+), 13 deletions(-) create mode 100644 sgl-kernel/setup_rocm.py create mode 100644 sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc diff --git a/docker/Dockerfile.rocm b/docker/Dockerfile.rocm index caa4666c8..480f80854 100644 --- a/docker/Dockerfile.rocm +++ b/docker/Dockerfile.rocm @@ -28,6 +28,9 @@ RUN git clone ${SGL_REPO} \ echo "Using ${SGL_BRANCH} branch."; \ git checkout ${SGL_BRANCH}; \ fi \ + && cd sgl-kernel \ + && python setup_rocm.py install \ + && cd .. \ && if [ "$BUILD_TYPE" = "srt" ]; then \ python -m pip --no-cache-dir install -e "python[srt_hip]"; \ else \ diff --git a/docs/start/install.md b/docs/start/install.md index b9702f021..fc1a936c6 100644 --- a/docs/start/install.md +++ b/docs/start/install.md @@ -32,7 +32,9 @@ git clone -b v0.4.2.post1 https://github.com/sgl-project/sglang.git cd sglang pip install --upgrade pip -pip install sgl-kernel --force-reinstall --no-deps +cd sgl-kernel +python setup_rocm.py install +cd .. pip install -e "python[all_hip]" ``` diff --git a/python/pyproject.toml b/python/pyproject.toml index c600ffc0d..f87a2702b 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -31,7 +31,7 @@ srt = [ # HIP (Heterogeneous-computing Interface for Portability) for AMD # => base docker rocm/vllm-dev:20241022, not from public vllm whl -srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11"] +srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.7.dev2", "outlines==0.1.11", "sgl-kernel>=0.0.3.post1"] # xpu is not enabled in public vllm and torch whl, # need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.htmlinstall vllm srt_xpu = ["sglang[runtime_common]", "outlines>=0.0.44,<0.1.0"] diff --git a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py index 32c8fcbb6..fab71809b 100644 --- a/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +++ b/python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py @@ -15,18 +15,10 @@ from vllm import _custom_ops as ops from sglang.srt.layers.moe.topk import select_experts from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8 -from sglang.srt.utils import ( - direct_register_custom_op, - get_device_name, - is_cuda_available, - is_hip, -) +from sglang.srt.utils import direct_register_custom_op, get_device_name, is_hip -is_cuda = is_cuda_available() is_hip_flag = is_hip() -if is_cuda: - from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size - +from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size logger = logging.getLogger(__name__) padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0 @@ -415,7 +407,7 @@ def moe_align_block_size( ) num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device) if num_experts >= 224: - if enable_moe_align_block_size_triton or is_hip_flag: + if enable_moe_align_block_size_triton: moe_align_block_size_triton( topk_ids, num_experts, diff --git a/sgl-kernel/setup_rocm.py b/sgl-kernel/setup_rocm.py new file mode 100644 index 000000000..6530cd7c7 --- /dev/null +++ b/sgl-kernel/setup_rocm.py @@ -0,0 +1,92 @@ +# Copyright 2025 SGLang Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +import multiprocessing +import os +import sys +from pathlib import Path + +import torch +from setuptools import find_packages, setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension + +root = Path(__file__).parent.resolve() + +if "bdist_wheel" in sys.argv and "--plat-name" not in sys.argv: + sys.argv.extend(["--plat-name", "manylinux2014_x86_64"]) + + +def _get_version(): + with open(root / "pyproject.toml") as f: + for line in f: + if line.startswith("version"): + return line.split("=")[1].strip().strip('"') + + +operator_namespace = "sgl_kernels" +include_dirs = [ + root / "src" / "sgl-kernel" / "include", + root / "src" / "sgl-kernel" / "csrc", +] + +sources = [ + "src/sgl-kernel/torch_extension_rocm.cc", + "src/sgl-kernel/csrc/moe_align_kernel.cu", +] + +cxx_flags = ["-O3"] +libraries = ["hiprtc", "amdhip64", "c10", "torch", "torch_python"] +extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", "-L/usr/lib/x86_64-linux-gnu"] + +hipcc_flags = [ + "-DNDEBUG", + f"-DOPERATOR_NAMESPACE={operator_namespace}", + "-O3", + "-Xcompiler", + "-fPIC", + "-std=c++17", + "-D__HIP_PLATFORM_AMD__=1", + "--amdgpu-target=gfx942", + "-DENABLE_BF16", + "-DENABLE_FP8", +] + +setup( + name="sgl-kernel", + version=_get_version(), + packages=find_packages(), + package_dir={"": "src"}, + ext_modules=[ + CUDAExtension( + name="sgl_kernel.ops._kernels", + sources=sources, + include_dirs=include_dirs, + extra_compile_args={ + "nvcc": hipcc_flags, + "cxx": cxx_flags, + }, + libraries=libraries, + extra_link_args=extra_link_args, + py_limited_api=True, + ), + ], + cmdclass={ + "build_ext": BuildExtension.with_options( + use_ninja=True, max_jobs=multiprocessing.cpu_count() + ) + }, + options={"bdist_wheel": {"py_limited_api": "cp39"}}, + install_requires=["torch"], +) diff --git a/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc b/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc new file mode 100644 index 000000000..22f40da10 --- /dev/null +++ b/sgl-kernel/src/sgl-kernel/torch_extension_rocm.cc @@ -0,0 +1,29 @@ +/* Copyright 2025 SGLang Team. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "sgl_kernels_ops.h" + +TORCH_LIBRARY_EXPAND(sgl_kernels, m) { + // moe_align_block_size + m.def( + "moe_align_block_size(Tensor topk_ids, int num_experts, int block_size, Tensor! sorted_token_ids, Tensor! " + "experts_ids, Tensor! num_tokens_post_pad, Tensor! token_cnts_buffer, Tensor! cumsum_buffer) -> ()"); + m.impl("moe_align_block_size", torch::kCUDA, &moe_align_block_size); +} + +REGISTER_EXTENSION(_kernels)