move more files under srt/utils (#11285)

This commit is contained in:
Lianmin Zheng
2025-10-09 16:46:15 -07:00
committed by GitHub
parent 758b887ad1
commit 9b8ebb2798
28 changed files with 96 additions and 55 deletions

View File

@@ -12,7 +12,10 @@ from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import add_prefix, align, is_cuda, is_hip, is_npu
if is_cuda():
import deep_gemm
try:
import deep_gemm
except ImportError as e:
deep_gemm = e
from sglang.srt.layers.attention.nsa.utils import NSA_DUAL_STREAM, NSA_USE_REAL_INDEXER
from sglang.srt.layers.dp_attention import get_attention_tp_group

View File

@@ -30,9 +30,9 @@ from sglang.srt.layers.quantization.modelopt_quant import (
ModelOptNvFp4FusedMoEMethod,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.offloader import get_offloader
from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
from sglang.srt.utils.offloader import get_offloader
if TYPE_CHECKING:
from sglang.srt.layers.moe.token_dispatcher import (

View File

@@ -11,7 +11,7 @@ _is_hip = is_hip()
@triton.jit
def fused_moe_router_kernel(
def fused_moe_router_cudacore_kernel(
input_ptr, # input (bs, hidden_dim)
moe_router_weight_ptr, # input (num_experts, hidden_dim)
topk_weights_ptr, # output (bs, topk)
@@ -114,7 +114,7 @@ def fused_moe_router_kernel(
# assert not moe_renormalize, "moe weight renormalization not implemented"
def fused_moe_router_impl(
def fused_moe_router_cudacore(
x: torch.Tensor,
router_weight: torch.Tensor,
topk: int,
@@ -138,7 +138,7 @@ def fused_moe_router_impl(
),
}
fused_moe_router_kernel[(bs,)](
fused_moe_router_cudacore_kernel[(bs,)](
x,
router_weight,
topk_weights,
@@ -157,7 +157,7 @@ def fused_moe_router_impl(
@triton.jit
def fused_moe_router_large_bs_kernel(
def fused_moe_router_tensorcore_kernel(
a_ptr, # input (bs, hidden_dim)
b_ptr, # input (num_experts, hidden_dim)
topk_weights_ptr, # output (bs, topk)
@@ -167,12 +167,15 @@ def fused_moe_router_large_bs_kernel(
topk: tl.constexpr, # only support topk <= 2
moe_softcapping: tl.constexpr,
moe_renormalize: tl.constexpr, # not supported
correction_bias_ptr,
is_correction_bias: tl.constexpr,
K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr,
stride_am: tl.constexpr,
stride_bn: tl.constexpr,
dp_attn_workaround_flag: tl.constexpr,
):
# 1. get block id
@@ -217,6 +220,20 @@ def fused_moe_router_large_bs_kernel(
exped = tl.exp(2 * logits_scaled)
logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
# Add bias after softcapping
if is_correction_bias:
bias = tl.load(
correction_bias_ptr + tl.arange(0, BLOCK_SIZE_N)[None, :],
mask=expert_mask.T,
other=0.0,
)
logits_softcapped = logits_softcapped + bias
if dp_attn_workaround_flag:
logits_softcapped = tl.where(
logits_softcapped != logits_softcapped, -1e9, logits_softcapped
)
# 5. top1
arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
cond_top1 = arange_block_size_n < num_experts
@@ -266,7 +283,7 @@ def fused_moe_router_large_bs_kernel(
)
def fused_moe_router_large_bs_impl(
def fused_moe_router_tensorcore(
x: torch.Tensor,
router_weight: torch.Tensor,
topk: int,
@@ -274,6 +291,7 @@ def fused_moe_router_large_bs_impl(
BLOCK_SIZE_M: int,
BLOCK_SIZE_N: int,
BLOCK_SIZE_K: int,
correction_bias: Optional[torch.Tensor] = None,
):
assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
bs, hidden_dim = x.shape
@@ -285,10 +303,17 @@ def fused_moe_router_large_bs_impl(
topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
is_correction_bias = correction_bias is not None
grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
fused_moe_router_large_bs_kernel[grid](
# TODO(ch-wan): temporary workaround for dp attention. We should support masked
# router to skip padded tokens.
from sglang.srt.layers.dp_attention import is_dp_attention_enabled
dp_attn_workaround_flag = is_dp_attention_enabled()
fused_moe_router_tensorcore_kernel[grid](
a_ptr=x,
b_ptr=router_weight,
topk_weights_ptr=topk_weights,
@@ -299,11 +324,14 @@ def fused_moe_router_large_bs_impl(
moe_softcapping=moe_softcapping,
moe_renormalize=False,
K=hidden_dim,
correction_bias_ptr=correction_bias,
is_correction_bias=is_correction_bias,
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
stride_am=hidden_dim,
stride_bn=hidden_dim,
dp_attn_workaround_flag=dp_attn_workaround_flag,
)
return topk_weights, topk_ids
@@ -316,6 +344,7 @@ def fused_moe_router_shim(
topk,
renormalize,
correction_bias: Optional[torch.Tensor] = None,
enable_deterministic_inference: bool = False,
):
assert not renormalize
assert (
@@ -324,16 +353,22 @@ def fused_moe_router_shim(
)
bs, hidden_dim = hidden_states.shape
num_experts = gating_output.shape[0]
BLOCK_SIZE_M = 32
BLOCK_SIZE_N = 16
BLOCK_SIZE_K = 256
BLOCK_SIZE_N = max(num_experts, 16)
BLOCK_SIZE_K = (
256 if num_experts < 256 else 64
) # with many experts, use a smaller K block; otherwise the kernel runs out of shared memory (OOM)
if (
bs >= 512
and topk <= 2
and num_experts <= BLOCK_SIZE_N
(bs >= 512 or num_experts > 8)
and hidden_dim % BLOCK_SIZE_K == 0
# keep using the single kernel in deterministic mode to avoid non-deterministic behavior
and not enable_deterministic_inference
):
return fused_moe_router_large_bs_impl(
# for a large batch size or many experts, use the kernel whose matmul runs on tensor cores
return fused_moe_router_tensorcore(
x=hidden_states,
router_weight=gating_output,
topk=topk,
@@ -341,9 +376,11 @@ def fused_moe_router_shim(
BLOCK_SIZE_M=BLOCK_SIZE_M,
BLOCK_SIZE_N=BLOCK_SIZE_N,
BLOCK_SIZE_K=BLOCK_SIZE_K,
correction_bias=correction_bias,
)
else:
return fused_moe_router_impl(
# otherwise, use the kernel whose matmul does not use tensor cores
return fused_moe_router_cudacore(
x=hidden_states,
router_weight=gating_output,
topk=topk,
@@ -380,11 +417,10 @@ class FusedMoeRouter:
renormalize=False,
)
def forward_vllm(
def forward_torch(
self,
x: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
# g, _ = self.router_linear.forward(x)
g = x.float() @ self.router_linear.weight.T.float()
g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping

View File

@@ -2,11 +2,10 @@ from typing import Callable, List, Optional, Tuple
import torch
from sglang.srt import offloader
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
from sglang.srt.utils import is_sm100_supported
from sglang.srt.utils import is_sm100_supported, offloader
try:
from vllm import _custom_ops as ops
@@ -29,7 +28,6 @@ from sglang.srt.layers.quantization.fp8_kernel import (
)
from sglang.srt.utils import (
align,
ceil_div,
get_bool_env_var,
get_cuda_version,
get_device_capability,