adapt to ds3.2
This commit is contained in:
@@ -4,7 +4,7 @@ import time
|
||||
|
||||
import torch
|
||||
|
||||
from sglang import ServerArgs
|
||||
from sglang.srt.server_args import ServerArgs
|
||||
from sglang.srt.managers.cache_controller import HiCacheController
|
||||
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
||||
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
||||
|
||||
@@ -127,21 +127,45 @@ class RMSNorm(CustomOp):
|
||||
return output, residual_out
|
||||
return rms_norm(x, self.weight.data, self.variance_epsilon)
|
||||
|
||||
# def forward_hip(
|
||||
# self,
|
||||
# x: torch.Tensor,
|
||||
# residual: Optional[torch.Tensor] = None,
|
||||
# ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
# if not x.is_contiguous():
|
||||
# # NOTE: Remove this if aiter kernel supports discontinuous input
|
||||
# x = x.contiguous()
|
||||
# if residual is not None:
|
||||
# if _vllm_version < Version("0.9"):
|
||||
# fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
||||
# return x, residual
|
||||
# else:
|
||||
# residual_out = torch.empty_like(x)
|
||||
# output = torch.empty_like(x)
|
||||
# fused_add_rms_norm(
|
||||
# output,
|
||||
# x,
|
||||
# residual_out,
|
||||
# residual,
|
||||
# self.weight.data,
|
||||
# self.variance_epsilon,
|
||||
# )
|
||||
# return output, residual_out
|
||||
# out = torch.empty_like(x)
|
||||
# rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||
# return out
|
||||
def forward_hip(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
residual: Optional[torch.Tensor] = None,
|
||||
):
|
||||
if not x.is_contiguous():
|
||||
# NOTE: Remove this if aiter kernel supports discontinuous input
|
||||
x = x.contiguous()
|
||||
|
||||
if residual is not None:
|
||||
if _vllm_version < Version("0.9"):
|
||||
fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
|
||||
return x, residual
|
||||
else:
|
||||
residual_out = torch.empty_like(x)
|
||||
try:
|
||||
output = torch.empty_like(x)
|
||||
residual_out = torch.empty_like(x)
|
||||
fused_add_rms_norm(
|
||||
output,
|
||||
x,
|
||||
@@ -151,10 +175,21 @@ class RMSNorm(CustomOp):
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return output, residual_out
|
||||
except TypeError:
|
||||
fused_add_rms_norm(
|
||||
x,
|
||||
residual,
|
||||
self.weight.data,
|
||||
self.variance_epsilon,
|
||||
)
|
||||
return x, residual
|
||||
|
||||
out = torch.empty_like(x)
|
||||
rms_norm(out, x, self.weight.data, self.variance_epsilon)
|
||||
return out
|
||||
|
||||
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
|
||||
@@ -61,7 +61,7 @@ def inplace_fused_experts(
|
||||
topk_ids: torch.Tensor,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
activation: int = 0,#0 silu 1 gelu
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@@ -79,6 +79,8 @@ def inplace_fused_experts(
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_limit: Optional[float] = None,
|
||||
) -> None:
|
||||
if isinstance(activation, int):
|
||||
activation = "silu" if activation == 0 else "gelu"
|
||||
fused_experts_impl(
|
||||
hidden_states,
|
||||
w1,
|
||||
@@ -117,7 +119,7 @@ def inplace_fused_experts_fake(
|
||||
topk_ids: torch.Tensor,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
activation: int = 0,#0 silu 1 gelu
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@@ -154,7 +156,7 @@ def outplace_fused_experts(
|
||||
topk_ids: torch.Tensor,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
activation: int = 0,#0 silu 1 gelu
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@@ -173,6 +175,8 @@ def outplace_fused_experts(
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_limit: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
if isinstance(activation, int):
|
||||
activation = "silu" if activation == 0 else "gelu"
|
||||
return fused_experts_impl(
|
||||
hidden_states,
|
||||
w1,
|
||||
@@ -211,7 +215,7 @@ def outplace_fused_experts_fake(
|
||||
topk_ids: torch.Tensor,
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
activation: int = 0,#0 silu 1 gelu
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@@ -263,6 +267,13 @@ def fused_experts(
|
||||
block_shape: Optional[List[int]] = None,
|
||||
):
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
act_id = (
|
||||
0 if (
|
||||
moe_runner_config.activation == 0
|
||||
or (isinstance(moe_runner_config.activation, str)
|
||||
and moe_runner_config.activation.lower() == "silu")
|
||||
) else 1
|
||||
)
|
||||
if moe_runner_config.inplace:
|
||||
assert not moe_runner_config.no_combine, "no combine + inplace makes no sense"
|
||||
torch.ops.sglang.inplace_fused_experts(
|
||||
@@ -273,7 +284,7 @@ def fused_experts(
|
||||
topk_ids,
|
||||
b1,
|
||||
b2,
|
||||
moe_runner_config.activation,
|
||||
act_id,
|
||||
moe_runner_config.apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
@@ -301,7 +312,7 @@ def fused_experts(
|
||||
topk_ids,
|
||||
b1,
|
||||
b2,
|
||||
moe_runner_config.activation,
|
||||
act_id,
|
||||
moe_runner_config.apply_router_weight_on_input,
|
||||
use_fp8_w8a8,
|
||||
use_int8_w8a8,
|
||||
@@ -345,7 +356,7 @@ def fused_experts_impl(
|
||||
b1: Optional[torch.Tensor] = None,
|
||||
b2: Optional[torch.Tensor] = None,
|
||||
inplace: bool = False,
|
||||
activation: str = "silu",
|
||||
activation: int = 0,#0 silu 1 gelu
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_fp8_w8a8: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
@@ -364,6 +375,9 @@ def fused_experts_impl(
|
||||
gemm1_alpha: Optional[float] = None,
|
||||
gemm1_limit: Optional[float] = None,
|
||||
):
|
||||
if isinstance(activation, int):
|
||||
activation = "silu" if activation == 0 else "gelu"
|
||||
|
||||
padded_size = padding_size
|
||||
if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
|
||||
padded_size = 0
|
||||
|
||||
@@ -516,7 +516,7 @@ class ModelRunner:
|
||||
):
|
||||
server_args.attention_backend = "fa3"
|
||||
elif _is_hip:
|
||||
server_args.attention_backend = "aiter"
|
||||
server_args.attention_backend = "triton"
|
||||
elif _is_npu:
|
||||
server_args.attention_backend = "ascend"
|
||||
else:
|
||||
|
||||
@@ -165,10 +165,10 @@ DINLINE void start_sync(
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__scoped_atomic_store_n(
|
||||
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
|
||||
__hip_atomic_store(
|
||||
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
|
||||
while (__hip_atomic_load(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) <
|
||||
flag)
|
||||
;
|
||||
}
|
||||
@@ -211,16 +211,16 @@ DINLINE void end_sync(
|
||||
if (threadIdx.x < ngpus) {
|
||||
// simultaneously write to the corresponding flag of all ranks.
|
||||
// Latency = 1 p2p write
|
||||
__scoped_atomic_store_n(
|
||||
__hip_atomic_store(
|
||||
&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
|
||||
flag,
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
|
||||
__MEMORY_SCOPE_SYSTEM);
|
||||
__HIP_MEMORY_SCOPE_SYSTEM);
|
||||
// wait until we got true from all ranks
|
||||
while (__scoped_atomic_load_n(
|
||||
while (__hip_atomic_load(
|
||||
&self_sg->end[blockIdx.x][threadIdx.x],
|
||||
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
|
||||
__MEMORY_SCOPE_DEVICE) < flag)
|
||||
__HIP_MEMORY_SCOPE_AGENT) < flag)
|
||||
;
|
||||
}
|
||||
__syncthreads();
|
||||
|
||||
@@ -21,6 +21,7 @@ limitations under the License.
|
||||
|
||||
#include "utils.h"
|
||||
|
||||
#define WARP_SIZE 64
|
||||
#define VEC_SIZE 4
|
||||
using Vec = int4;
|
||||
|
||||
@@ -45,7 +46,7 @@ __device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffff
|
||||
int original = v;
|
||||
#pragma unroll
|
||||
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
|
||||
int n = __shfl_up_sync(mask, v, offset);
|
||||
int n = __shfl_up(v, offset);
|
||||
if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
|
||||
}
|
||||
return v - original;
|
||||
|
||||
@@ -60,7 +60,7 @@ template <typename T>
|
||||
__device__ float convert_to_float(T x) {
|
||||
if constexpr (std::is_same_v<T, __half>) {
|
||||
return __half2float(x);
|
||||
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
|
||||
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
|
||||
return __bfloat162float(x);
|
||||
} else if constexpr (std::is_same_v<T, float>) {
|
||||
return x;
|
||||
@@ -575,8 +575,8 @@ void topk_softmax(
|
||||
renormalize,
|
||||
stream);
|
||||
} else if (dtype == at::ScalarType::BFloat16) {
|
||||
topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
|
||||
reinterpret_cast<const __nv_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
|
||||
topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
|
||||
reinterpret_cast<const __hip_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
|
||||
topk_weights.data_ptr<float>(),
|
||||
topk_indices.data_ptr<int>(),
|
||||
softmax_workspace.data_ptr<float>(),
|
||||
|
||||
@@ -358,25 +358,25 @@ __device__ __forceinline__ dstDtype castFromFloat(float val) {
|
||||
#endif
|
||||
|
||||
// add FP8 support
|
||||
#ifndef USE_ROCM
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
||||
#else // USE_ROCM
|
||||
#if HIP_FP8_TYPE_FNUZ
|
||||
#include <c10/util/Float8_e4m3fnuz.h>
|
||||
using FP8_TYPE = c10::Float8_e4m3fnuz;
|
||||
constexpr auto FP8_E4M3_MAX = 224.0f;
|
||||
#else
|
||||
#if HIP_FP8_TYPE_E4M3
|
||||
#include <c10/util/Float8_e4m3fn.h>
|
||||
using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
||||
#else
|
||||
#error "fp8 is not supported in this processor (arch < gfx942)."
|
||||
#endif // HIP_FP8_TYPE_E4M3
|
||||
#endif // HIP_FP8_TYPE_FNUZ
|
||||
#endif // USE_ROCM
|
||||
// #ifndef USE_ROCM
|
||||
// #include <c10/util/Float8_e4m3fn.h>
|
||||
// using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
||||
// #else // USE_ROCM
|
||||
// #if HIP_FP8_TYPE_FNUZ
|
||||
// #include <c10/util/Float8_e4m3fnuz.h>
|
||||
// using FP8_TYPE = c10::Float8_e4m3fnuz;
|
||||
// constexpr auto FP8_E4M3_MAX = 224.0f;
|
||||
// #else
|
||||
// #if HIP_FP8_TYPE_E4M3
|
||||
// #include <c10/util/Float8_e4m3fn.h>
|
||||
// using FP8_TYPE = c10::Float8_e4m3fn;
|
||||
// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
|
||||
// #else
|
||||
// #error "fp8 is not supported in this processor (arch < gfx942)."
|
||||
// #endif // HIP_FP8_TYPE_E4M3
|
||||
// #endif // HIP_FP8_TYPE_FNUZ
|
||||
// #endif // USE_ROCM
|
||||
|
||||
#define FULL_MASK 0xffffffff
|
||||
|
||||
|
||||
100
sgl-kernel/setup_hip.py
Normal file
100
sgl-kernel/setup_hip.py
Normal file
@@ -0,0 +1,100 @@
|
||||
# Copyright 2025 SGLang Team. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
|
||||
import platform
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
from setuptools import find_packages, setup
|
||||
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
|
||||
|
||||
root = Path(__file__).parent.resolve()
|
||||
arch = platform.machine().lower()
|
||||
|
||||
|
||||
def _get_version():
|
||||
with open(root / "pyproject.toml") as f:
|
||||
for line in f:
|
||||
if line.startswith("version"):
|
||||
return line.split("=")[1].strip().strip('"')
|
||||
|
||||
|
||||
operator_namespace = "sgl_kernel"
|
||||
include_dirs = [
|
||||
root / "include",
|
||||
root / "csrc",
|
||||
]
|
||||
|
||||
sources = [
|
||||
"csrc/allreduce/custom_all_reduce.hip",
|
||||
"csrc/allreduce/quick_all_reduce.cu",
|
||||
"csrc/common_extension_rocm.cc",
|
||||
"csrc/elementwise/activation.cu",
|
||||
"csrc/grammar/apply_token_bitmask_inplace_cuda.cu",
|
||||
"csrc/moe/moe_align_kernel.cu",
|
||||
"csrc/moe/moe_topk_softmax_kernels.cu",
|
||||
"csrc/speculative/eagle_utils.cu",
|
||||
"csrc/kvcacheio/transfer.cu",
|
||||
]
|
||||
|
||||
cxx_flags = [
|
||||
"-O3",
|
||||
"-Wno-switch-bool",
|
||||
"-Wno-macro-redefined",
|
||||
"-Wno-deprecated-declarations",
|
||||
"-w",
|
||||
]
|
||||
libraries = ["c10", "torch", "torch_python"]
|
||||
extra_link_args = ["-Wl,-rpath,$ORIGIN/../../torch/lib", f"-L/usr/lib/{arch}-linux-gnu"]
|
||||
|
||||
hipcc_flags = [
|
||||
"-fPIC",
|
||||
"-O3",
|
||||
"-std=c++17",
|
||||
"-D__HIP_PLATFORM_HCC__=1",
|
||||
"--offload-arch=gfx928",
|
||||
"--offload-arch=gfx936",
|
||||
"--gpu-max-threads-per-block=1024",
|
||||
"-Wno-macro-redefined",
|
||||
"-Wno-deprecated-declarations",
|
||||
"-funroll-loops",
|
||||
"-Rpass-analysis=unroll-loops",
|
||||
"-w",
|
||||
]
|
||||
|
||||
ext_modules = [
|
||||
CUDAExtension(
|
||||
name="sgl_kernel.common_ops",
|
||||
sources=sources,
|
||||
include_dirs=include_dirs,
|
||||
extra_compile_args={
|
||||
"nvcc": hipcc_flags,
|
||||
"cxx": cxx_flags,
|
||||
},
|
||||
libraries=libraries,
|
||||
extra_link_args=extra_link_args,
|
||||
py_limited_api=False,
|
||||
),
|
||||
]
|
||||
|
||||
setup(
|
||||
name="sgl-kernel",
|
||||
version=_get_version(),
|
||||
packages=find_packages(where="python"),
|
||||
package_dir={"": "python"},
|
||||
ext_modules=ext_modules,
|
||||
cmdclass={"build_ext": BuildExtension.with_options(use_ninja=True)},
|
||||
options={"bdist_wheel": {"py_limited_api": "cp39"}},
|
||||
)
|
||||
Reference in New Issue
Block a user