Fuse routed scaling factor in topk_reduce kernel (#6220)

This commit is contained in:
Xiaoyu Zhang
2025-06-08 02:06:50 +08:00
committed by GitHub
parent f5599ef124
commit 515ef4facb
10 changed files with 331 additions and 9 deletions

View File

@@ -1155,6 +1155,7 @@ def inplace_fused_experts(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
routed_scaling_factor: Optional[float] = None,
) -> None:
fused_experts_impl(
hidden_states,
@@ -1177,6 +1178,8 @@ def inplace_fused_experts(
a1_scale,
a2_scale,
block_shape,
False,
routed_scaling_factor,
)
@@ -1200,6 +1203,7 @@ def inplace_fused_experts_fake(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
routed_scaling_factor: Optional[float] = None,
) -> None:
pass
@@ -1233,6 +1237,7 @@ def outplace_fused_experts(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return fused_experts_impl(
hidden_states,
@@ -1256,6 +1261,7 @@ def outplace_fused_experts(
a2_scale,
block_shape,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
@@ -1280,6 +1286,7 @@ def outplace_fused_experts_fake(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
@@ -1314,7 +1321,9 @@ def fused_experts(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
):
if inplace:
assert not no_combine, "no combine + inplace makes no sense"
torch.ops.sglang.inplace_fused_experts(
@@ -1337,6 +1346,7 @@ def fused_experts(
a1_scale,
a2_scale,
block_shape,
routed_scaling_factor,
)
return hidden_states
else:
@@ -1361,9 +1371,102 @@ def fused_experts(
a2_scale,
block_shape,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
# _moe_sum_reduce_kernel is adapted from https://github.com/ModelTC/lightllm/blob/main/lightllm/common/fused_moe/moe_sum_reduce.py
@triton.jit
def _moe_sum_reduce_kernel(
    input_ptr,
    input_stride_0,
    input_stride_1,
    input_stride_2,
    output_ptr,
    output_stride_0,
    output_stride_1,
    token_num: int,
    topk_num: int,
    hidden_dim: int,
    routed_scaling_factor: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_DIM: tl.constexpr,
    NUM_STAGE: tl.constexpr,
):
    """Sum the input over its second axis, scale the sum, and store it.

    For every token handled by this program, the ``topk_num`` slices reached
    through ``input_stride_1`` are accumulated in float32, multiplied by
    ``routed_scaling_factor`` and written to the matching output row.

    Grid layout: program axis 0 selects a group of ``BLOCK_M`` tokens and
    program axis 1 selects a tile of ``BLOCK_DIM`` elements along the last
    axis.

    ``input_stride_2`` and ``output_stride_1`` are accepted but never read:
    element offsets along the last axis are added to the pointers directly,
    i.e. a unit stride is assumed there.
    """
    # Widen the strides that get multiplied by an index to 64-bit integers.
    input_stride_0 = tl.cast(input_stride_0, dtype=tl.int64)
    input_stride_1 = tl.cast(input_stride_1, dtype=tl.int64)
    output_stride_0 = tl.cast(output_stride_0, dtype=tl.int64)

    pid_token = tl.program_id(0)
    pid_dim = tl.program_id(1)

    # Half-open token range owned by this program, clamped to token_num.
    first_token = pid_token * BLOCK_M
    last_token = min((pid_token + 1) * BLOCK_M, token_num)

    # Tile along the last axis, clamped to hidden_dim.
    dim_start = pid_dim * BLOCK_DIM
    dim_end = min((pid_dim + 1) * BLOCK_DIM, hidden_dim)
    offs_dim = dim_start + tl.arange(0, BLOCK_DIM)
    # The mask depends only on the tile, so compute it once up front.
    in_bounds = offs_dim < dim_end

    for token_index in range(first_token, last_token):
        acc = tl.zeros((BLOCK_DIM,), dtype=tl.float32)
        row_ptr = input_ptr + token_index * input_stride_0 + offs_dim
        for k in tl.range(0, topk_num, num_stages=NUM_STAGE):
            # Masked-out lanes read as 0.0 and do not affect the sum.
            chunk = tl.load(row_ptr + k * input_stride_1, mask=in_bounds, other=0.0)
            acc += chunk
        acc = acc * routed_scaling_factor
        out_row_ptr = output_ptr + token_index * output_stride_0 + offs_dim
        # The stored value is cast to the element type of the *input* pointer.
        tl.store(
            out_row_ptr,
            acc.to(input_ptr.dtype.element_ty),
            mask=in_bounds,
        )
def moe_sum_reduce_triton(
    input: torch.Tensor, output: torch.Tensor, routed_scaling_factor: float
):
    """Launch ``_moe_sum_reduce_kernel`` to reduce ``input`` into ``output``.

    Args:
        input: Contiguous 3-D tensor unpacked as
            ``(token_num, topk_num, hidden_dim)``.
        output: Contiguous tensor whose first two sizes must equal
            ``(token_num, hidden_dim)``; it is written in place.
        routed_scaling_factor: Multiplier passed to the kernel, which applies
            it to each accumulated sum before storing.

    Raises:
        AssertionError: If either tensor is non-contiguous or the output
            sizes do not match the input.
    """
    assert input.is_contiguous()
    assert output.is_contiguous()

    token_num, topk_num, hidden_dim = input.shape
    assert output.shape[0] == token_num
    assert output.shape[1] == hidden_dim

    # Fixed launch configuration.
    BLOCK_M = 1
    BLOCK_DIM = 2048
    NUM_STAGE = 1
    num_warps = 8

    # One program per (token block, tile along the last axis).
    token_blocks = triton.cdiv(token_num, BLOCK_M)
    dim_blocks = triton.cdiv(hidden_dim, BLOCK_DIM)

    _moe_sum_reduce_kernel[(token_blocks, dim_blocks)](
        input,
        *input.stride(),
        output,
        *output.stride(),
        token_num=token_num,
        topk_num=topk_num,
        hidden_dim=hidden_dim,
        routed_scaling_factor=routed_scaling_factor,
        BLOCK_M=BLOCK_M,
        BLOCK_DIM=BLOCK_DIM,
        NUM_STAGE=NUM_STAGE,
        num_warps=num_warps,
    )
@torch.compile
def moe_sum_reduce_torch_compile(x, out, routed_scaling_factor):
    """Reduce ``x`` over dim 1 into ``out``, then scale ``out`` in place.

    Args:
        x: Tensor to be summed along dimension 1.
        out: Destination tensor; receives the sum and is then multiplied
            in place by ``routed_scaling_factor``.
        routed_scaling_factor: Multiplier applied to the reduced result.

    Returns:
        None. The result is delivered through ``out``.
    """
    torch.sum(x, dim=1, keepdim=False, out=out)
    out *= routed_scaling_factor
def fused_experts_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
@@ -1386,6 +1489,7 @@ def fused_experts_impl(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
):
padded_size = padding_size
if (
@@ -1562,6 +1666,9 @@ def fused_experts_impl(
block_shape=block_shape,
)
if routed_scaling_factor is None:
routed_scaling_factor = 1.0
if no_combine:
pass
elif _is_hip:
@@ -1570,20 +1677,28 @@ def fused_experts_impl(
out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
else:
if topk_ids.shape[1] == 1:
if topk_ids.shape[1] == 1 and routed_scaling_factor == 1.0:
pass # we write directly into out_hidden_states
elif topk_ids.shape[1] == 2:
elif topk_ids.shape[1] == 2 and routed_scaling_factor == 1.0:
torch.add(
intermediate_cache3[:, 0],
intermediate_cache3[:, 1],
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
).squeeze(dim=1)
elif topk_ids.shape[1] > 2:
torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
)
else:
# According to micro-benchmark results, torch.compile performs better for small token counts.
if tokens_in_chunk <= 32:
moe_sum_reduce_torch_compile(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
routed_scaling_factor,
)
else:
moe_sum_reduce_triton(
intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx],
routed_scaling_factor,
)
return out_hidden_states
@@ -1695,4 +1810,5 @@ def fused_moe(
a2_scale=a2_scale,
block_shape=block_shape,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)

View File

@@ -225,6 +225,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
def forward_cpu(

View File

@@ -411,4 +411,5 @@ class BlockInt8MoEMethod:
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)

View File

@@ -317,6 +317,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
apply_router_weight_on_input=apply_router_weight_on_input,
routed_scaling_factor=routed_scaling_factor,
)

View File

@@ -1030,6 +1030,7 @@ class Fp8MoEMethod:
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
def maybe_apply_hip_fused_experts(

View File

@@ -388,6 +388,7 @@ class MoeWNA16Method:
w2_zp=layer.w2_qzeros if has_zp else None,
block_shape=[0, layer.group_size],
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
@staticmethod

View File

@@ -328,4 +328,5 @@ class W8A8FP8MoEMethod:
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)

View File

@@ -268,4 +268,5 @@ class W8A8Int8MoEMethod:
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)

View File

@@ -346,7 +346,7 @@ class DeepseekV2MoE(nn.Module):
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1: