Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)

Co-authored-by: SangBin Cho <rkooo567@gmail.com>
Co-authored-by: dhou-xai <dhou@x.ai>
Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
Lianmin Zheng
2025-03-03 00:12:04 -08:00
parent 0194948fd9
commit ac2387279e
86 changed files with 4116 additions and 2015 deletions


@@ -144,6 +144,73 @@ def silu_and_mul_triton_kernel(
            tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)


@triton.jit
def tanh(x):
    # tanh implemented via sigmoid: tanh(x) = 2 * sigmoid(2x) - 1
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def gelu_and_mul_triton_kernel(
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
    BLOCK_SIZE: tl.constexpr,
):
    # One program per routed (token, expert) row: apply tanh-approximated GeLU to the
    # gate half, multiply by the up half, optionally rescale, and store into down_input.
    InDtype = gateup_output.dtype.element_ty
    OutDtype = down_input.dtype.element_ty

    half_hidden_size = hidden_size // 2

    pid = tl.program_id(0)
    expert_id = tl.load(reorder_topk_ids + pid)
    if expert_id >= start_expert_id and expert_id <= end_expert_id:
        gateup_output_ptr = gateup_output + pid * hidden_size
        gate_output_ptr = gateup_output_ptr
        up_output_ptr = gateup_output_ptr + half_hidden_size
        down_input_ptr = down_input + pid * half_hidden_size

        if scales is not None:
            scale = tl.load(scales + expert_id - start_expert_id)
            scale = (1 / scale).to(InDtype)
        else:
            scale = 1

        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < half_hidden_size

            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
            up_output = tl.load(up_output_ptr + offset, mask=mask)

            # gelu & mul & quantize
            # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
            # kAlpha = sqrt(2 / pi)
            kAlpha = 0.7978845608028654
            gate_output = (
                0.5
                * gate_output
                * (
                    1
                    + tanh(
                        kAlpha
                        * (
                            gate_output
                            + 0.044715 * gate_output * gate_output * gate_output
                        )
                    )
                )
            )
            gate_output = gate_output.to(InDtype)

            gelu_mul_output = gate_output * up_output * scale
            gelu_mul_output = gelu_mul_output.to(OutDtype)
            tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)


@triton.jit
def post_reorder_triton_kernel(
    down_output_ptr,
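For reference, the new kernel computes the tanh-approximated GeLU of the gate half, multiplies it by the up half, and applies the optional per-expert scale. A minimal host-side sketch of the same math, assuming a [num_tokens, hidden_size] gateup buffer (function and tensor names here are illustrative, not part of the commit):

import torch

def gelu_and_mul_reference(gateup_output: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
    # First half of the last dim is the gate projection, second half is the up projection.
    gate, up = gateup_output.chunk(2, dim=-1)
    gate = torch.nn.functional.gelu(gate.float(), approximate="tanh").to(gateup_output.dtype)
    return gate * up * scale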


@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.moe.ep_moe.kernels import (
    gelu_and_mul_triton_kernel,
    grouped_gemm_triton,
    post_reorder_triton_kernel,
    pre_reorder_triton_kernel,
@@ -296,6 +297,17 @@ class EPMoE(torch.nn.Module):
                self.end_expert_id,
                BLOCK_SIZE=512,
            )
        elif self.activation == "gelu":
            gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
                gateup_output,
                down_input,
                gateup_output.shape[1],
                reorder_topk_ids,
                self.w2_input_scale,
                self.start_expert_id,
                self.end_expert_id,
                BLOCK_SIZE=512,
            )
        else:
            raise ValueError(f"Unsupported activation: {self.activation=}")
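The launch grid is one program per row of gateup_output, i.e. one routed (token, expert) pair, mirroring the existing silu branch. A hedged standalone launch sketch; the shapes, dtypes, and single-expert setup below are illustrative assumptions, not taken from the diff:

import torch
from sglang.srt.layers.moe.ep_moe.kernels import gelu_and_mul_triton_kernel

num_rows, hidden_size = 16, 256  # hidden_size is twice the intermediate size
gateup_output = torch.randn(num_rows, hidden_size, device="cuda", dtype=torch.float16)
down_input = torch.empty(num_rows, hidden_size // 2, device="cuda", dtype=torch.float16)
reorder_topk_ids = torch.zeros(num_rows, dtype=torch.int32, device="cuda")  # every row routed to expert 0

gelu_and_mul_triton_kernel[(num_rows,)](
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    None,  # no w2 input scale, so the kernel uses scale = 1
    0,     # start_expert_id
    0,     # end_expert_id
    BLOCK_SIZE=512,
)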


@@ -24,6 +24,8 @@ def fused_moe_forward_native(
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
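In the native path, select_experts produces the routing before any expert computation. For intuition, a naive stand-in for that routing step (softmax over router logits, top-k, optional renormalization); this sketch ignores grouped top-k, custom routing functions, and correction bias, and is not the library routine:

import torch

def naive_select_experts(router_logits: torch.Tensor, top_k: int, renormalize: bool = True):
    # router_logits: [num_tokens, num_experts]
    probs = torch.softmax(router_logits.float(), dim=-1)
    topk_weights, topk_ids = probs.topk(top_k, dim=-1)  # both [num_tokens, top_k]
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return topk_weights, topk_ids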


@@ -23,7 +23,7 @@ from sglang.srt.utils import (
is_hip,
)
is_hip_flag = is_hip()
is_hip_ = is_hip()
logger = logging.getLogger(__name__)
@@ -487,6 +487,7 @@ def invoke_fused_moe_kernel(
use_int8_w8a8: bool,
use_int8_w8a16: bool,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
@@ -646,7 +647,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2 if is_hip_flag else 4,
"num_stages": 2 if is_hip_ else 4,
}
if M <= E:
config = {
@@ -655,7 +656,7 @@ def get_default_config(
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2 if is_hip_flag else 4,
"num_stages": 2 if is_hip_ else 4,
}
else:
# Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
@@ -665,7 +666,7 @@ def get_default_config(
"BLOCK_SIZE_K": block_shape[1],
"GROUP_SIZE_M": 32,
"num_warps": 4,
"num_stages": 2 if is_hip_flag else 3,
"num_stages": 2 if is_hip_ else 3,
}
else:
config = {
@@ -814,6 +815,7 @@ def outplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
@@ -831,6 +833,7 @@ def outplace_fused_experts(
a1_scale,
a2_scale,
block_shape,
no_combine=no_combine,
)
@@ -849,6 +852,7 @@ def outplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -877,8 +881,10 @@ def fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
):
if inplace:
assert not no_combine, "no combine + inplace makes no sense"
torch.ops.sglang.inplace_fused_experts(
hidden_states,
w1,
@@ -912,6 +918,7 @@ def fused_experts(
a1_scale,
a2_scale,
block_shape,
no_combine=no_combine,
)
@@ -931,6 +938,7 @@ def fused_experts_impl(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
):
padded_size = padding_size
if not use_fp8_w8a8 or not use_int8_w8a8 or block_shape is not None:
@@ -987,7 +995,14 @@ def fused_experts_impl(
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
if inplace:
if no_combine:
assert not inplace
out_hidden_states = torch.empty(
(num_tokens, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
elif inplace:
out_hidden_states = hidden_states
else:
out_hidden_states = torch.empty_like(hidden_states)
@@ -1057,7 +1072,11 @@ def fused_experts_impl(
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
(
intermediate_cache3
if not no_combine and topk_ids.shape[1] != 1
else out_hidden_states[begin_chunk_idx:end_chunk_idx]
),
a2_scale,
w2_scale,
curr_topk_weights,
@@ -1075,16 +1094,16 @@ def fused_experts_impl(
block_shape=block_shape,
)
if is_hip_flag:
if no_combine:
pass
elif is_hip_:
ops.moe_sum(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
else:
if topk_ids.shape[1] == 1:
out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
intermediate_cache3[:, 0]
)
pass # we write directly into out_hidden_states
elif topk_ids.shape[1] == 2:
torch.add(
intermediate_cache3[:, 0],
@@ -1122,6 +1141,7 @@ def fused_moe(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1191,4 +1211,5 @@ def fused_moe(
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
no_combine=no_combine,
)
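With no_combine=True, out_hidden_states is allocated as [num_tokens, top_k, hidden] and the down-projection kernel writes into it directly. Since the normal epilogue shown above only sums or adds the per-expert slices (moe_sum, torch.add), the routing weights are evidently already folded in upstream; with no_combine the caller performs that same reduction itself, roughly as in this illustrative sketch:

import torch

def combine(per_expert_out: torch.Tensor) -> torch.Tensor:
    # per_expert_out: [num_tokens, top_k, hidden]; routing weights already applied,
    # so combining is a plain sum over the top_k dimension.
    return per_expert_out.sum(dim=1)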


@@ -125,6 +125,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
return self.forward(
x=x,
@@ -138,6 +140,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
activation=activation,
inplace=inplace,
no_combine=no_combine,
)
def forward_cuda(
@@ -153,6 +157,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
inplace: bool = True,
no_combine: bool = False,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
@@ -171,6 +177,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
from aiter.fused_moe import fused_experts_ck
assert activation == "silu", f"{activation=} is not supported."
assert not no_combine, "unsupported"
return fused_experts_ck(
hidden_states=x,
@@ -186,8 +193,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
inplace=inplace and not no_combine,
activation=activation,
no_combine=no_combine,
)
def forward_cpu(
@@ -202,6 +210,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
inplace: bool = True,
) -> torch.Tensor:
return moe_forward_native(
layer,
@@ -241,6 +250,7 @@ class FusedMoE(torch.nn.Module):
reduce_results: Whether to apply all_reduce on the output of the layer
renormalize: Whether to renormalize the logits in the fused_moe kernel
quant_config: Quantization configuration.
inplace: Hint to compute in place (may modify the input activation).
"""
def __init__(
@@ -262,6 +272,8 @@ class FusedMoE(torch.nn.Module):
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
use_presharded_weights: bool = False,
inplace: bool = True,
no_combine: bool = False,
):
super().__init__()
@@ -285,6 +297,9 @@ class FusedMoE(torch.nn.Module):
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.activation = activation
self.use_presharded_weights = use_presharded_weights
self.inplace = inplace
self.no_combine = no_combine
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
@@ -304,7 +319,6 @@ class FusedMoE(torch.nn.Module):
params_dtype=params_dtype,
weight_loader=self.weight_loader,
)
self.use_presharded_weights = use_presharded_weights
def _load_per_tensor_weight_scale(
self,
@@ -598,6 +612,8 @@ class FusedMoE(torch.nn.Module):
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
activation=self.activation,
inplace=self.inplace,
no_combine=self.no_combine,
)
if self.reduce_results and self.tp_size > 1:
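Putting the pieces together, a plain-PyTorch reference of what the fused path computes under the new flags, handy for spot-checking the Triton kernels. Everything here is an illustrative sketch (tensor names, the loop-based layout, and the weight handling are assumptions, not library code); routing weights are folded into the per-expert outputs so that, as in the fused path, the final combine is a plain sum:

import torch
import torch.nn.functional as F

def moe_reference(x, w1, w2, topk_weights, topk_ids, activation="silu", no_combine=False):
    # x: [T, H], w1: [E, 2I, H], w2: [E, H, I]
    T, K = topk_ids.shape
    act = F.silu if activation == "silu" else (lambda t: F.gelu(t, approximate="tanh"))
    out = x.new_zeros(T, K, w2.shape[1])
    for t in range(T):
        for k in range(K):
            e = topk_ids[t, k]
            gate, up = (x[t] @ w1[e].t()).chunk(2)   # gate/up halves, each [I]
            down = (act(gate) * up) @ w2[e].t()      # down projection, [H]
            out[t, k] = topk_weights[t, k] * down    # fold in the routing weight
    if no_combine:
        return out           # [T, K, H]; caller combines later
    return out.sum(dim=1)    # [T, H]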