Support penalty in overlap mode; return logprob with chunked prefill; improve benchmark scripts (#3988)
Co-authored-by: SangBin Cho <rkooo567@gmail.com> Co-authored-by: dhou-xai <dhou@x.ai> Co-authored-by: Hanming Lu <hanming_lu@berkeley.edu>
This commit is contained in:
@@ -144,6 +144,73 @@ def silu_and_mul_triton_kernel(
|
||||
tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
def tanh(x):
    # tanh implemented via the identity tanh(x) = 2*sigmoid(2x) - 1,
    # using Triton's tl.sigmoid as the building block.
    return 2 * tl.sigmoid(2 * x) - 1
@triton.jit
def gelu_and_mul_triton_kernel(
    gateup_output,
    down_input,
    hidden_size,
    reorder_topk_ids,
    scales,
    start_expert_id,
    end_expert_id,
    BLOCK_SIZE: tl.constexpr,
):
    # Fused GELU(gate) * up activation with optional per-expert scaling.
    #
    # One program handles one row (pid) of `gateup_output`:
    #   gateup_output: rows of length `hidden_size`; the first half is the
    #     gate projection, the second half the up projection.
    #   down_input: output rows of length `hidden_size // 2`.
    #   reorder_topk_ids: per-row expert id; rows whose expert id falls
    #     outside [start_expert_id, end_expert_id] are left untouched.
    #   scales: optional per-local-expert scales; when present, the output
    #     is multiplied by 1/scale (quantization to `down_input`'s dtype).
    InDtype = gateup_output.dtype.element_ty
    OutDtype = down_input.dtype.element_ty

    half_hidden_size = hidden_size // 2

    pid = tl.program_id(0)
    expert_id = tl.load(reorder_topk_ids + pid)
    if expert_id >= start_expert_id and expert_id <= end_expert_id:
        # Base pointers for this row: gate half, up half, and output row.
        gateup_output_ptr = gateup_output + pid * hidden_size
        gate_output_ptr = gateup_output_ptr
        up_output_ptr = gateup_output_ptr + half_hidden_size
        down_input_ptr = down_input + pid * half_hidden_size

        if scales is not None:
            # Index scales by the expert's local (per-partition) id.
            scale = tl.load(scales + expert_id - start_expert_id)
            scale = (1 / scale).to(InDtype)
        else:
            scale = 1

        # Sweep the half-row in BLOCK_SIZE tiles; `mask` guards the tail tile.
        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
            offset = start_offset + tl.arange(0, BLOCK_SIZE)
            mask = offset < half_hidden_size

            # GELU is computed in fp32 for accuracy, then cast back.
            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
            up_output = tl.load(up_output_ptr + offset, mask=mask)

            # gelu & mul & quantize
            # Tanh-approximation GELU:
            #   0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
            # https://pytorch.org/docs/stable/generated/torch.nn.GELU.html
            # sqrt(2/pi)
            kAlpha = 0.7978845608028654
            gate_output = (
                0.5
                * gate_output
                * (
                    1
                    + tanh(
                        kAlpha
                        * (
                            gate_output
                            + 0.044715 * gate_output * gate_output * gate_output
                        )
                    )
                )
            )
            gate_output = gate_output.to(InDtype)

            # Multiply by the up projection, apply optional scale, and store
            # in the output dtype (possibly a quantized type).
            gelu_mul_output = gate_output * up_output * scale
            gelu_mul_output = gelu_mul_output.to(OutDtype)
            tl.store(down_input_ptr + offset, gelu_mul_output, mask=mask)
|
||||
|
||||
|
||||
@triton.jit
|
||||
def post_reorder_triton_kernel(
|
||||
down_output_ptr,
|
||||
|
||||
@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||
gelu_and_mul_triton_kernel,
|
||||
grouped_gemm_triton,
|
||||
post_reorder_triton_kernel,
|
||||
pre_reorder_triton_kernel,
|
||||
@@ -296,6 +297,17 @@ class EPMoE(torch.nn.Module):
|
||||
self.end_expert_id,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
elif self.activation == "gelu":
|
||||
gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
|
||||
gateup_output,
|
||||
down_input,
|
||||
gateup_output.shape[1],
|
||||
reorder_topk_ids,
|
||||
self.w2_input_scale,
|
||||
self.start_expert_id,
|
||||
self.end_expert_id,
|
||||
BLOCK_SIZE=512,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported activation: {self.activation=}")
|
||||
|
||||
|
||||
@@ -24,6 +24,8 @@ def fused_moe_forward_native(
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
|
||||
@@ -23,7 +23,7 @@ from sglang.srt.utils import (
|
||||
is_hip,
|
||||
)
|
||||
|
||||
is_hip_flag = is_hip()
|
||||
is_hip_ = is_hip()
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -487,6 +487,7 @@ def invoke_fused_moe_kernel(
|
||||
use_int8_w8a8: bool,
|
||||
use_int8_w8a16: bool,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
) -> None:
|
||||
assert topk_weights.stride(1) == 1
|
||||
assert sorted_token_ids.stride(0) == 1
|
||||
@@ -646,7 +647,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 8,
|
||||
"num_stages": 2 if is_hip_flag else 4,
|
||||
"num_stages": 2 if is_hip_ else 4,
|
||||
}
|
||||
if M <= E:
|
||||
config = {
|
||||
@@ -655,7 +656,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_K": 128,
|
||||
"GROUP_SIZE_M": 1,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2 if is_hip_flag else 4,
|
||||
"num_stages": 2 if is_hip_ else 4,
|
||||
}
|
||||
else:
|
||||
# Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
|
||||
@@ -665,7 +666,7 @@ def get_default_config(
|
||||
"BLOCK_SIZE_K": block_shape[1],
|
||||
"GROUP_SIZE_M": 32,
|
||||
"num_warps": 4,
|
||||
"num_stages": 2 if is_hip_flag else 3,
|
||||
"num_stages": 2 if is_hip_ else 3,
|
||||
}
|
||||
else:
|
||||
config = {
|
||||
@@ -814,6 +815,7 @@ def outplace_fused_experts(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return fused_experts_impl(
|
||||
hidden_states,
|
||||
@@ -831,6 +833,7 @@ def outplace_fused_experts(
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
block_shape,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
|
||||
@@ -849,6 +852,7 @@ def outplace_fused_experts_fake(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return torch.empty_like(hidden_states)
|
||||
|
||||
@@ -877,8 +881,10 @@ def fused_experts(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
):
|
||||
if inplace:
|
||||
assert not no_combine, "no combine + inplace makes no sense"
|
||||
torch.ops.sglang.inplace_fused_experts(
|
||||
hidden_states,
|
||||
w1,
|
||||
@@ -912,6 +918,7 @@ def fused_experts(
|
||||
a1_scale,
|
||||
a2_scale,
|
||||
block_shape,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
|
||||
@@ -931,6 +938,7 @@ def fused_experts_impl(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
):
|
||||
padded_size = padding_size
|
||||
if not use_fp8_w8a8 or not use_int8_w8a8 or block_shape is not None:
|
||||
@@ -987,7 +995,14 @@ def fused_experts_impl(
|
||||
|
||||
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
|
||||
|
||||
if inplace:
|
||||
if no_combine:
|
||||
assert not inplace
|
||||
out_hidden_states = torch.empty(
|
||||
(num_tokens, topk_ids.shape[1], w2.shape[1]),
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype,
|
||||
)
|
||||
elif inplace:
|
||||
out_hidden_states = hidden_states
|
||||
else:
|
||||
out_hidden_states = torch.empty_like(hidden_states)
|
||||
@@ -1057,7 +1072,11 @@ def fused_experts_impl(
|
||||
invoke_fused_moe_kernel(
|
||||
intermediate_cache2,
|
||||
w2,
|
||||
intermediate_cache3,
|
||||
(
|
||||
intermediate_cache3
|
||||
if not no_combine and topk_ids.shape[1] != 1
|
||||
else out_hidden_states[begin_chunk_idx:end_chunk_idx]
|
||||
),
|
||||
a2_scale,
|
||||
w2_scale,
|
||||
curr_topk_weights,
|
||||
@@ -1075,16 +1094,16 @@ def fused_experts_impl(
|
||||
block_shape=block_shape,
|
||||
)
|
||||
|
||||
if is_hip_flag:
|
||||
if no_combine:
|
||||
pass
|
||||
elif is_hip_:
|
||||
ops.moe_sum(
|
||||
intermediate_cache3.view(*intermediate_cache3.shape),
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx],
|
||||
)
|
||||
else:
|
||||
if topk_ids.shape[1] == 1:
|
||||
out_hidden_states[begin_chunk_idx:end_chunk_idx].copy_(
|
||||
intermediate_cache3[:, 0]
|
||||
)
|
||||
pass # we write directly into out_hidden_states
|
||||
elif topk_ids.shape[1] == 2:
|
||||
torch.add(
|
||||
intermediate_cache3[:, 0],
|
||||
@@ -1122,6 +1141,7 @@ def fused_moe(
|
||||
a1_scale: Optional[torch.Tensor] = None,
|
||||
a2_scale: Optional[torch.Tensor] = None,
|
||||
block_shape: Optional[List[int]] = None,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
This function computes a Mixture of Experts (MoE) layer using two sets of
|
||||
@@ -1191,4 +1211,5 @@ def fused_moe(
|
||||
a1_scale=a1_scale,
|
||||
a2_scale=a2_scale,
|
||||
block_shape=block_shape,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
@@ -125,6 +125,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
return self.forward(
|
||||
x=x,
|
||||
@@ -138,6 +140,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
activation=activation,
|
||||
inplace=inplace,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
def forward_cuda(
|
||||
@@ -153,6 +157,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
) -> torch.Tensor:
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
@@ -171,6 +177,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
from aiter.fused_moe import fused_experts_ck
|
||||
|
||||
assert activation == "silu", f"{activation=} is not supported."
|
||||
assert not no_combine, "unsupported"
|
||||
|
||||
return fused_experts_ck(
|
||||
hidden_states=x,
|
||||
@@ -186,8 +193,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
inplace=True,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
no_combine=no_combine,
|
||||
)
|
||||
|
||||
def forward_cpu(
|
||||
@@ -202,6 +210,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
inplace: bool = True,
|
||||
) -> torch.Tensor:
|
||||
return moe_forward_native(
|
||||
layer,
|
||||
@@ -241,6 +250,7 @@ class FusedMoE(torch.nn.Module):
|
||||
reduce_results: Whether to all all_reduce on the output of the layer
|
||||
renomalize: Whether to renormalize the logits in the fused_moe kernel
|
||||
quant_config: Quantization configure.
|
||||
inplace: suggestion to compute inplace (modify input activation).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -262,6 +272,8 @@ class FusedMoE(torch.nn.Module):
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
activation: str = "silu",
|
||||
use_presharded_weights: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
@@ -285,6 +297,9 @@ class FusedMoE(torch.nn.Module):
|
||||
self.custom_routing_function = custom_routing_function
|
||||
self.correction_bias = correction_bias
|
||||
self.activation = activation
|
||||
self.use_presharded_weights = use_presharded_weights
|
||||
self.inplace = inplace
|
||||
self.no_combine = no_combine
|
||||
|
||||
if quant_config is None:
|
||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||
@@ -304,7 +319,6 @@ class FusedMoE(torch.nn.Module):
|
||||
params_dtype=params_dtype,
|
||||
weight_loader=self.weight_loader,
|
||||
)
|
||||
self.use_presharded_weights = use_presharded_weights
|
||||
|
||||
def _load_per_tensor_weight_scale(
|
||||
self,
|
||||
@@ -598,6 +612,8 @@ class FusedMoE(torch.nn.Module):
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
correction_bias=self.correction_bias,
|
||||
activation=self.activation,
|
||||
inplace=self.inplace,
|
||||
no_combine=self.no_combine,
|
||||
)
|
||||
|
||||
if self.reduce_results and self.tp_size > 1:
|
||||
|
||||
Reference in New Issue
Block a user