[1/N] MoE Refactor: refactor select_experts (#7966)
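This refactor moves expert selection out of the individual MoE and quantization backends and into a standalone `TopK` op that returns a `TopKOutput` named tuple (`topk_weights`, `topk_ids`, `router_logits`). `EPMoE.forward`, `FusedMoE.forward`, `fused_experts`, and the quant-method `apply` hooks now take that tuple instead of raw `router_logits` plus the routing flags, so the routing parameters no longer have to be threaded through every layer. A rough sketch of the new calling convention, using only names that appear in the hunks below; the constructor arguments and surrounding wiring are illustrative, not part of this commit, and the snippet assumes an initialized sglang tensor-parallel environment rather than running standalone:

    import torch
    from sglang.srt.layers.moe.topk import TopK, TopKOutput
    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

    # Expert selection is now its own layer...
    topk = TopK(top_k=2, renormalize=True)
    experts = FusedMoE(num_experts=8, top_k=2, hidden_size=1024, intermediate_size=4096, layer_id=0)

    hidden_states = torch.randn(4, 1024)   # dummy activations
    router_logits = torch.randn(4, 8)      # dummy gate output

    # ...so routing runs once, and its result is handed to the experts as a tuple.
    topk_output: TopKOutput = topk(hidden_states, router_logits)
    out = experts(hidden_states, topk_output)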
@@ -29,15 +29,18 @@ class CustomOp(nn.Module):
self._original_forward_method = self._forward_method
# NOTE: Temporarily workaround MoE
# The performance of torch.compile on this layer is not always good when bs > 1,
# so we decide to only use torch.compile when bs=1
if "FusedMoE" in self.__class__.__name__:
if num_tokens == 1:
from sglang.srt.layers.moe.fused_moe_native import (
fused_moe_forward_native,
)
# The performance of torch.compile on this layer is not always good when bs > 1,
# so we decide to only use torch.compile when bs =1
self._forward_method = fused_moe_forward_native
elif "TopK" in self.__class__.__name__:
if num_tokens == 1:
self._forward_method = self.forward_native
else:
self._forward_method = self.forward_native
self.is_torch_compile = True
@@ -756,7 +756,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias: bool = True,
skip_bias_add: bool = False,
params_dtype: Optional[torch.dtype] = None,
quant_config: Optional["QuantizationConfig"] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
@@ -1,17 +1,13 @@
import logging
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple
import einops
import torch
from torch.nn import Module
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,

@@ -28,7 +24,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
tma_align_input_scale,
)
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
@@ -162,16 +158,9 @@ class EPMoE(torch.nn.Module):
intermediate_size: int,
layer_id: int,
params_dtype: Optional[torch.dtype] = None,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
use_per_token_if_dynamic: bool = True,

@@ -189,24 +178,12 @@ class EPMoE(torch.nn.Module):
self.layer_id = layer_id
self.num_experts = num_experts
assert self.num_experts % self.tp_size == 0
assert (
num_fused_shared_experts == 0
), "num_fused_shared_experts is not supported in EP"
self.num_fused_shared_experts = num_fused_shared_experts
self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
self.start_expert_id = self.tp_rank * self.num_experts_per_partition
self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
self.top_k = top_k
self.intermediate_size = intermediate_size
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.custom_routing_function = custom_routing_function
self.activation = activation
self.routed_scaling_factor = routed_scaling_factor
self.use_per_token_if_dynamic = use_per_token_if_dynamic
@@ -311,33 +288,24 @@ class EPMoE(torch.nn.Module):
)
return (local_num_experts, expert_map)
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
return self.forward_deepgemm(hidden_states, router_logits)
return self.forward_deepgemm(hidden_states, topk_output)
else:
return self.forward_normal(hidden_states, router_logits)
return self.forward_normal(hidden_states, topk_output)
def forward_deepgemm(
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
self,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
):
assert self.quant_method is not None
assert self.activation == "silu"
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
)
topk_weights, topk_ids, _ = topk_output
if not self.use_block_quant:
# Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
@@ -469,8 +437,10 @@ class EPMoE(torch.nn.Module):
)
return output
def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
assert self.quant_method is not None
topk_weights, topk_ids, _ = topk_output
hidden_states_shape = hidden_states.shape
hidden_states_dtype = hidden_states.dtype
hidden_states_device = hidden_states.device

@@ -481,23 +451,6 @@ class EPMoE(torch.nn.Module):
use_per_token_if_dynamic=self.use_per_token_if_dynamic,
)
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
routed_scaling_factor=self.routed_scaling_factor,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
),
)
if self.use_w4afp8:
local_topk_ids = topk_ids
if self.expert_map is not None:
@@ -916,16 +869,9 @@ class DeepEPMoE(EPMoE):
intermediate_size: int,
layer_id: int,
params_dtype: Optional[torch.dtype] = None,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
routed_scaling_factor: Optional[float] = None,
deepep_mode: DeepEPMode = DeepEPMode.auto,

@@ -937,16 +883,9 @@ class DeepEPMoE(EPMoE):
intermediate_size=intermediate_size,
layer_id=layer_id,
params_dtype=params_dtype,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
topk_group=topk_group,
quant_config=quant_config,
tp_size=tp_size,
prefix=prefix,
correction_bias=correction_bias,
custom_routing_function=custom_routing_function,
activation=activation,
routed_scaling_factor=routed_scaling_factor,
)
@@ -9,21 +9,14 @@ import torch
from torch.nn import functional as F
from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKOutput
def fused_moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,

@@ -34,20 +27,7 @@ def fused_moe_forward_native(
if apply_router_weight_on_input:
raise NotImplementedError()
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
torch_native=True,
)
topk_weights, topk_ids, _ = topk_output
w13_weights = layer.w13_weight[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)

@@ -67,15 +47,8 @@ def fused_moe_forward_native(
def moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,

@@ -86,20 +59,7 @@ def moe_forward_native(
if apply_router_weight_on_input:
raise NotImplementedError()
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
routed_scaling_factor=routed_scaling_factor,
)
topk_weights, topk_ids, _ = topk_output
# Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
len_experts = layer.num_experts
@@ -6,13 +6,13 @@ import functools
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
import torch
import triton
import triton.language as tl
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.fp8_kernel import (
per_token_group_quant_fp8,
scaled_fp8_quant,

@@ -1328,8 +1328,7 @@ def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
topk_output: TopKOutput,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,

@@ -1348,7 +1347,7 @@ def fused_experts(
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
):
topk_weights, topk_ids, _ = topk_output
if inplace:
assert not no_combine, "no combine + inplace makes no sense"
torch.ops.sglang.inplace_fused_experts(
@@ -1732,17 +1731,10 @@ def fused_moe(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
topk_output: TopKOutput,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_fp8_w8a8: bool = False,
use_int8_w8a8: bool = False,
use_int8_w8a16: bool = False,

@@ -1766,16 +1758,9 @@ def fused_moe(
- hidden_states (torch.Tensor): The input tensor to the MoE layer.
- w1 (torch.Tensor): The first set of expert weights.
- w2 (torch.Tensor): The second set of expert weights.
- gating_output (torch.Tensor): The output of the gating operation
(before softmax).
- topk (int): The number of top-k experts to select.
- renormalize (bool): If True, renormalize the top-k weights to sum to 1.
- topk_output (TopKOutput): The top-k output of the experts.
- inplace (bool): If True, perform the operation in-place.
Defaults to False.
- num_expert_group: Optional[int]: additional parameter for grouped_topk
- topk_group: Optional[int]: additional parameter for grouped_topk
- use_grouped_topk: If True, use grouped_topk instead of fused_topk
note: Deepseek V2/V3/R1 series models use grouped_topk
- use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
products for w1 and w2. Defaults to False.
- use_int8_w8a8 (bool): If True, use int8 arithmetic to compute the inner
@@ -1799,28 +1784,12 @@
Returns:
- torch.Tensor: The output tensor after applying the MoE layer.
"""
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=gating_output,
use_grouped_topk=use_grouped_topk,
top_k=topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
routed_scaling_factor=routed_scaling_factor,
)
return fused_experts(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
topk_output,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
@@ -2,7 +2,7 @@
import logging
from enum import Enum
from typing import Callable, List, Optional, Tuple
from typing import List, Optional, Tuple
import torch

@@ -11,6 +11,7 @@ from sglang.srt.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.layers.moe.topk import TopKOutput
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,

@@ -59,22 +60,15 @@ class FusedMoE(torch.nn.Module):
def __init__(
self,
num_experts: int,
top_k: int,
hidden_size: int,
intermediate_size: int,
top_k: Optional[int] = None,
layer_id: Optional[int] = None,
params_dtype: Optional[torch.dtype] = None,
reduce_results: bool = False,
renormalize: bool = True,
use_grouped_topk: bool = False,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
topk_group: Optional[int] = None,
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_presharded_weights: bool = False,

@@ -89,6 +83,7 @@ class FusedMoE(torch.nn.Module):
if params_dtype is None:
params_dtype = torch.get_default_dtype()
self.top_k = top_k
self.hidden_size = hidden_size
self.tp_size = (
tp_size if tp_size is not None else get_tensor_model_parallel_world_size()

@@ -126,19 +121,9 @@ class FusedMoE(torch.nn.Module):
self.ep_rank = 0
self.local_num_experts = num_experts
self.routed_scaling_factor = routed_scaling_factor
self.top_k = top_k
assert intermediate_size % self.tp_size == 0
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.reduce_results = reduce_results
self.renormalize = renormalize
self.use_grouped_topk = use_grouped_topk
if self.use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.activation = activation
self.apply_router_weight_on_input = apply_router_weight_on_input
self.use_presharded_weights = use_presharded_weights

@@ -562,22 +547,14 @@ class FusedMoE(torch.nn.Module):
)
return
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
assert self.quant_method is not None
# Matrix multiply.
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
renormalize=self.renormalize,
use_grouped_topk=self.use_grouped_topk,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
topk_output=topk_output,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
routed_scaling_factor=self.routed_scaling_factor,
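With `FusedMoE.forward` reduced to `(hidden_states, topk_output)`, the `quant_method.apply` call above forwards the precomputed routing result instead of re-running `select_experts` in every backend. A minimal sketch of what a backend `apply` looks like under the new interface, mirroring the abstract `FusedMoEMethodBase.apply` and the `fused_experts` signature changed elsewhere in this commit (the surrounding class and imports are assumed):

    def apply(self, layer, x, topk_output, *, activation="silu", inplace=True, **kwargs):
        # No expert selection here any more: the TopK op already produced topk_output.
        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_output,
            inplace=inplace,
            activation=activation,
        )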
@@ -12,12 +12,15 @@
# limitations under the License.
# ==============================================================================
from __future__ import annotations
import math
from typing import Callable, Optional
from typing import TYPE_CHECKING, Callable, NamedTuple, Optional
import torch
import torch.nn.functional as F
from sglang.srt.custom_op import CustomOp
from sglang.srt.eplb import expert_location_dispatch
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location_dispatch import (

@@ -52,6 +55,168 @@ if _use_aiter:
except ImportError:
raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
if _is_npu:
import torch_npu
class TopKOutput(NamedTuple):
topk_weights: torch.Tensor
topk_ids: torch.Tensor
router_logits: torch.Tensor
class TopK(CustomOp):
# TODO(ch-wan): support triton_kernels
def __init__(
self,
top_k: int,
*,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
renormalize: bool = True,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
correction_bias: Optional[torch.Tensor] = None,
routed_scaling_factor: Optional[float] = None,
):
# NOTE: scoring_func is not used for now, but we keep it for future use
# see https://github.com/sgl-project/sglang/pull/4505 for more details
super().__init__()
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
self.top_k = top_k
self.use_grouped_topk = use_grouped_topk
self.renormalize = renormalize
self.topk_group = topk_group
self.num_expert_group = num_expert_group
self.num_fused_shared_experts = num_fused_shared_experts
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
self.routed_scaling_factor = routed_scaling_factor
def forward_native(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
torch_native = True
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
def forward_cuda(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
torch_native = False
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
def forward_cpu(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
def forward_npu(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
*,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
) -> TopKOutput:
global_num_experts = router_logits.shape[-1]
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
return torch_npu.npu_moe_gating_top_k(
router_logits,
k=self.top_k,
bias=self.correction_bias,
k_group=self.topk_group,
group_count=self.num_expert_group,
group_select_mode=1,
renorm=0,
norm_type=1,
routed_scaling_factor=1,
eps=float(1e-20),
)
else:
torch_native = True
return select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
torch_native=torch_native,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=num_token_non_padded,
expert_location_dispatch_info=expert_location_dispatch_info,
)
def fused_topk_torch_native(
hidden_states: torch.Tensor,
@@ -436,8 +601,9 @@ def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
*,
use_grouped_topk: bool = False,
renormalize: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,

@@ -447,7 +613,7 @@ def select_experts(
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
):
) -> TopKOutput:
router_logits, correction_bias = (
expert_location_dispatch.transform_select_experts_inputs(
router_logits=router_logits,

@@ -522,4 +688,4 @@ def select_experts(
get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
return topk_weights, topk_ids
return TopKOutput(topk_weights, topk_ids, router_logits)
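`TopKOutput` is a plain `NamedTuple`, which is why the call sites in this commit can unpack it positionally (`topk_weights, topk_ids, _ = topk_output`) or read the logits back out when a kernel still needs them (the Marlin-based paths unpack all three fields). A small self-contained illustration of that pattern with dummy tensors, assuming softmax scoring and top-2 routing:

    from typing import NamedTuple

    import torch

    class TopKOutput(NamedTuple):
        topk_weights: torch.Tensor
        topk_ids: torch.Tensor
        router_logits: torch.Tensor

    router_logits = torch.randn(4, 8)  # [num_tokens, num_experts]
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, k=2, dim=-1)
    out = TopKOutput(topk_weights, topk_ids, router_logits)

    # Positional unpacking, as in the refactored kernels:
    topk_weights, topk_ids, _ = out
    # Named access when the logits are still needed downstream:
    logits = out.router_logits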
@@ -1,7 +1,9 @@
|
||||
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import inspect
|
||||
from typing import Callable, Dict, Optional, Type, Union
|
||||
from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -65,6 +67,9 @@ from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
|
||||
from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
# Base quantization methods that don't depend on vllm
|
||||
BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
|
||||
"fp8": Fp8Config,
|
||||
@@ -186,15 +191,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
@@ -208,20 +206,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
|
||||
"self": self,
|
||||
"layer": layer,
|
||||
"x": x,
|
||||
"router_logits": router_logits,
|
||||
"top_k": top_k,
|
||||
"renormalize": renormalize,
|
||||
"use_grouped_topk": use_grouped_topk,
|
||||
"topk_group": topk_group,
|
||||
"num_expert_group": num_expert_group,
|
||||
"custom_routing_function": custom_routing_function,
|
||||
"topk_output": topk_output,
|
||||
}
|
||||
if correction_bias is not None:
|
||||
if not has_correction_bias:
|
||||
raise ValueError(
|
||||
"Please increase the version of your vllm. Try `pip install vllm==0.9.0.1`"
|
||||
)
|
||||
kwargs["e_score_correction_bias"] = correction_bias
|
||||
return original_apply(**kwargs)
|
||||
|
||||
setattr(class_obj, "apply", new_apply)
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
@@ -33,6 +33,9 @@ from sglang.srt.layers.quantization.scalar_type import scalar_types
|
||||
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import replace_parameter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
|
||||
@@ -737,45 +740,19 @@ class AWQMoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert (
|
||||
scoring_func == "softmax"
|
||||
), "Only softmax score func is supported for now."
|
||||
|
||||
# The input must currently be float16
|
||||
orig_dtype = x.dtype
|
||||
x = x.half()
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
topk_weights, topk_ids, router_logits = topk_output
|
||||
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
|
||||
@@ -1,12 +1,16 @@
|
||||
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Type
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
|
||||
class QuantizeMethodBase(ABC):
|
||||
"""Base class for different quantized methods."""
|
||||
@@ -88,19 +92,22 @@ class FusedMoEMethodBase(QuantizeMethodBase):
|
||||
params_dtype: torch.dtype,
|
||||
**extra_weight_attrs,
|
||||
):
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError()
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class QuantizationConfig(ABC):
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn import Module
|
||||
@@ -21,6 +21,9 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
|
||||
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
||||
from sglang.srt.utils import set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
ACTIVATION_SCHEMES = ["static", "dynamic"]
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -344,15 +347,8 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
@@ -360,30 +356,13 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
# Expert selection
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
# Expert fusion with INT8 quantization
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
topk_output=topk_output,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
|
||||
@@ -1,15 +1,17 @@
|
||||
# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import logging
|
||||
from enum import Enum
|
||||
from typing import Callable, List, Optional
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
import torch
|
||||
from compressed_tensors import CompressionFormat
|
||||
from compressed_tensors.quantization import QuantizationStrategy
|
||||
|
||||
from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
|
||||
from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
|
||||
from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
|
||||
from sglang.srt.layers.quantization.utils import (
|
||||
@@ -20,6 +22,12 @@ from sglang.srt.layers.quantization.utils import (
|
||||
)
|
||||
from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
|
||||
CompressedTensorsConfig,
|
||||
)
|
||||
|
||||
_is_cuda = is_cuda()
|
||||
_is_npu = is_npu()
|
||||
_is_cpu_amx_available = cpu_has_amx_support()
|
||||
@@ -51,7 +59,7 @@ __all__ = [
|
||||
]
|
||||
|
||||
|
||||
class CompressedTensorsMoEMethod:
|
||||
class CompressedTensorsMoEMethod(FusedMoEMethodBase):
|
||||
def __new__(cls, *args, **kwargs):
|
||||
if cls is CompressedTensorsMoEMethod:
|
||||
return super().__new__(cls)
|
||||
@@ -59,7 +67,7 @@ class CompressedTensorsMoEMethod:
|
||||
|
||||
@staticmethod
|
||||
def get_moe_method(
|
||||
quant_config: "CompressedTensorsConfig", # type: ignore # noqa E501
|
||||
quant_config: CompressedTensorsConfig,
|
||||
) -> "CompressedTensorsMoEMethod":
|
||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||
# are supported + check if the layer is being ignored.
|
||||
@@ -82,9 +90,7 @@ class CompressedTensorsMoEMethod:
|
||||
|
||||
class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
):
|
||||
def __init__(self, quant_config: CompressedTensorsConfig):
|
||||
self.quant_config = quant_config
|
||||
self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
|
||||
self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
|
||||
@@ -270,47 +276,21 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
no_combine: bool = False,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton import fused_experts
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
topk_output=topk_output,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
@@ -327,9 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
|
||||
def __init__(
|
||||
self, quant_config: "CompressedTensorsConfig" # type: ignore # noqa E501
|
||||
):
|
||||
def __init__(self, quant_config: CompressedTensorsConfig):
|
||||
self.quant_config = quant_config
|
||||
# TODO: @dsikka: refactor this to use schemes as other kernels
|
||||
# are supported + check if the layer is being ignored.
|
||||
@@ -628,43 +606,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
if expert_map is not None:
|
||||
raise NotImplementedError(
|
||||
"Expert Parallelism is not supported for " "fused Marlin MoE method."
|
||||
)
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
topk_weights, topk_ids, router_logits = topk_output
|
||||
|
||||
return torch.ops.vllm.fused_marlin_moe(
|
||||
x,
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
@@ -78,6 +78,7 @@ from sglang.srt.utils import (
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
|
||||
|
||||
_is_hip = is_hip()
|
||||
@@ -971,15 +972,8 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: int = 0,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
@@ -987,26 +981,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
# Expert selection
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
if use_intel_amx_backend(layer):
|
||||
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
x, topk_weights = apply_topk_weights_cpu(
|
||||
apply_router_weight_on_input, topk_weights, x
|
||||
)
|
||||
@@ -1032,8 +1011,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
ret = self.maybe_apply_hip_fused_experts(
|
||||
layer,
|
||||
x,
|
||||
topk_weights,
|
||||
topk_ids,
|
||||
topk_output,
|
||||
activation,
|
||||
no_combine,
|
||||
)
|
||||
@@ -1048,6 +1026,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
):
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
return cutlass_fused_experts_fp8(
|
||||
x,
|
||||
layer.w13_weight.transpose(1, 2),
|
||||
@@ -1076,8 +1055,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
topk_output=topk_output,
|
||||
inplace=inplace and not no_combine,
|
||||
activation=activation,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
@@ -1101,11 +1079,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
activation: str = "silu",
|
||||
no_combine: bool = False,
|
||||
) -> Optional[torch.Tensor]:
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
if _use_hip_int4:
|
||||
# TODO: add triton kernel and add check _use_aiter
|
||||
assert not no_combine, f"{no_combine=} is not supported."
|
||||
@@ -1397,14 +1375,8 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_output: TopKOutput,
|
||||
) -> torch.Tensor:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from fractions import Fraction
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
@@ -43,6 +43,9 @@ from sglang.srt.layers.quantization.utils import (
|
||||
unpack_cols,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
try:
|
||||
from vllm import _custom_ops as ops
|
||||
except ImportError:
|
||||
@@ -1057,42 +1060,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool = False,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
# Delay the import to avoid circular dependency
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
assert (
|
||||
scoring_func == "softmax"
|
||||
), "Only softmax score func is supported for now."
|
||||
|
||||
# The input must currently be float16
|
||||
orig_dtype = x.dtype
|
||||
x = x.half()
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=e_score_correction_bias,
|
||||
)
|
||||
topk_weights, topk_ids, router_logits = topk_output
|
||||
|
||||
return fused_marlin_moe(
|
||||
x,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
from torch.nn.parameter import Parameter
|
||||
@@ -31,6 +31,9 @@ from sglang.srt.layers.quantization.utils import (
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.utils import is_cuda, next_power_of_2
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.layers.moe.topk import TopKOutput
|
||||
|
||||
if is_cuda():
|
||||
from sgl_kernel import cutlass_scaled_fp4_mm, scaled_fp4_quant
|
||||
|
||||
@@ -402,15 +405,8 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
@@ -418,29 +414,12 @@ class ModelOptFp8MoEMethod(FusedMoEMethodBase):
|
||||
routed_scaling_factor: Optional[float] = None,
|
||||
) -> torch.Tensor:
|
||||
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
# Expert selection
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
return fused_experts(
|
||||
x,
|
||||
layer.w13_weight,
|
||||
layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
topk_output=topk_output,
|
||||
inplace=inplace,
|
||||
activation=activation,
|
||||
use_fp8_w8a8=True,
|
||||
@@ -961,15 +940,8 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
use_grouped_topk: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
num_fused_shared_experts: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
correction_bias: Optional[torch.Tensor] = None,
|
||||
topk_output: TopKOutput,
|
||||
*,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
inplace: bool = True,
|
||||
@@ -982,21 +954,6 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
) -> torch.Tensor:
|
||||
|
||||
assert activation == "silu", "Only SiLU activation is supported."
|
||||
from sglang.srt.layers.moe.topk import select_experts
|
||||
|
||||
topk_weights, topk_ids = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
top_k=top_k,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
num_fused_shared_experts=num_fused_shared_experts,
|
||||
custom_routing_function=custom_routing_function,
|
||||
correction_bias=correction_bias,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
)
|
||||
|
||||
if self.enable_flashinfer_moe:
|
||||
assert (
|
||||
@@ -1004,6 +961,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
), "apply_router_weight_on_input is not supported for Flashinfer"
|
||||
# TRTLLM Cutlass moe takes in activations in BF16/Half/nvfp4 precision
|
||||
# and fp4 quantized weights loaded from the checkpoint
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
output = flashinfer_cutlass_fused_moe(
|
||||
x,
|
||||
topk_ids.to(torch.int),
|
||||
@@ -1029,6 +987,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
|
||||
|
||||
from sglang.srt.layers.moe.cutlass_moe import cutlass_moe_fp4
|
||||
|
||||
topk_weights, topk_ids, _ = topk_output
|
||||
return cutlass_moe_fp4(
|
||||
a=x,
|
||||
a1_gscale=layer.w13_input_scale_quant,
|
||||
|
||||
@@ -2,8 +2,9 @@
from __future__ import annotations

import logging
from typing import Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import numpy as np
import torch

from sglang.srt.distributed import get_tensor_model_parallel_rank
@@ -20,6 +21,9 @@ from sglang.srt.utils import get_device_capability, set_weight_attrs

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput


def get_weight_perm(num_bits: int):
perm_list: List[int] = []
@@ -348,15 +352,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool = False,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
@@ -365,22 +362,8 @@ class MoeWNA16Method(FusedMoEMethodBase):
) -> torch.Tensor:
# avoid circular import
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts

assert activation == "silu", "Only SiLU activation is supported."
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)

weight_bits = self.quant_config.weight_bits
has_zp = self.quant_config.has_zp
@@ -389,8 +372,7 @@ class MoeWNA16Method(FusedMoEMethodBase):
x,
layer.w13_qweight,
layer.w2_qweight,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk_output=topk_output,
inplace=inplace,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int4_w4a16=weight_bits == 4,

@@ -1,5 +1,7 @@
from __future__ import annotations

import importlib
from typing import Callable, List, Optional
from typing import TYPE_CHECKING, Callable, List, Optional

import torch
import torch.nn.functional as F
@@ -21,6 +23,9 @@ from sglang.srt.utils import (
use_intel_amx_backend,
)

if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput

has_triton_kernels = importlib.util.find_spec("triton_kernels") is not None


@@ -125,25 +130,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
super().__init__()
self.use_triton_kernels = use_triton_kernels

from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

if torch.cuda.is_available():
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

if has_triton_kernels:
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_forward,
)
else:
triton_kernel_moe_forward = None
else:
fused_experts = None # type: ignore
triton_kernel_moe_forward = None

self.moe_forward_native = moe_forward_native
self.fused_experts = fused_experts
self.triton_kernel_moe_forward = triton_kernel_moe_forward

def create_weights(
self,
layer: torch.nn.Module,
@@ -201,34 +187,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:

return self.forward(
x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
renormalize=renormalize,
use_grouped_topk=use_grouped_topk,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
topk_output=topk_output,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
@@ -240,15 +210,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
@@ -257,33 +220,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor:

if self.use_triton_kernels:
return self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
# TODO(ch-wan): re-enable the Triton kernel
raise NotImplementedError("The Triton kernel is temporarily disabled.")
# return triton_kernel_moe_forward(
# hidden_states=x,
# w1=layer.w13_weight,
# w2=layer.w2_weight,
# gating_output=router_logits,
# topk=top_k,
# renormalize=renormalize,
# )
else:
from sglang.srt.layers.moe.topk import select_experts

topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)

if _use_aiter:
assert not no_combine, "unsupported"
topk_weights, topk_ids, _ = topk_output
if apply_router_weight_on_input:
assert (
topk_weights.dim() == 2
@@ -296,7 +246,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_weights = torch.ones_like(
topk_weights, dtype=torch.float32
) # topk_weights must be FP32 (float32)

return fused_moe(
x,
layer.w13_weight,
@@ -310,12 +259,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
),
)
else:
return self.fused_experts(
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_experts,
)

return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk_output=topk_output,
inplace=inplace and not no_combine,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
@@ -327,15 +279,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
@@ -344,30 +289,13 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
) -> torch.Tensor:
assert activation == "silu", f"activation = {activation} is not supported."

if use_intel_amx_backend(layer):
if use_intel_amx_backend(layer) and not apply_router_weight_on_input:
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

from sglang.srt.layers.moe.topk import (
apply_topk_weights_cpu,
select_experts,
)

topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)
topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
)

return torch.ops.sgl_kernel.fused_experts_cpu(
x,
layer.w13_weight,
@@ -385,61 +313,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
True, # is_vnni
)
else:
return self.moe_forward_native(
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

return moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
topk_output,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)

def forward_npu(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
return self.moe_forward_native(
from sglang.srt.layers.moe.fused_moe_native import moe_forward_native

return moe_forward_native(
layer,
x,
use_grouped_topk,
top_k,
router_logits,
renormalize,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
inplace,
no_combine,
routed_scaling_factor,
topk_output,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)

def forward_tpu(self, *args, **kwargs) -> torch.Tensor:
@@ -508,13 +417,7 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
hidden_states: torch.Tensor,
topk_output: TopKOutput,
) -> torch.Tensor:
raise NotImplementedError

@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, Callable, Dict, List, Optional
from typing import TYPE_CHECKING, Any, Dict, List, Optional

import torch
from torch.nn.parameter import Parameter
@@ -25,6 +25,9 @@ from sglang.srt.layers.quantization.fp8_utils import (
)
from sglang.srt.utils import set_weight_attrs

if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput

_is_fp8_fnuz = is_fp8_fnuz()


@@ -266,45 +269,23 @@ class W8A8FP8MoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts

# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)

return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk_output=topk_output,
inplace=inplace,
apply_router_weight_on_input=apply_router_weight_on_input,
activation=activation,
use_fp8_w8a8=True,
per_channel_quant=True,

@@ -3,7 +3,7 @@ from __future__ import annotations
import importlib
import sys
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast
from typing import TYPE_CHECKING, Any, Dict, List, Mapping, Optional, Tuple, Union, cast

import torch
from torch.nn.parameter import Parameter
@@ -37,6 +37,9 @@ from sglang.srt.utils import (
use_intel_amx_backend,
)

if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput

_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
@@ -239,7 +242,7 @@ class W8A8Int8Config(QuantizationConfig):
layer: torch.nn.Module,
prefix: str,
) -> Optional[QuantizeMethodBase]:
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

if _is_npu:
@@ -469,15 +472,8 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
self,
layer: torch.nn.Module,
x: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
use_grouped_topk: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
num_fused_shared_experts: int = 0,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
topk_output: TopKOutput,
*,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
inplace: bool = True,
@@ -485,26 +481,11 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
routed_scaling_factor: Optional[float] = None,
) -> torch.Tensor:
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts

# Expert selection
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
routed_scaling_factor=routed_scaling_factor,
)

if use_intel_amx_backend(layer):
from sglang.srt.layers.moe.topk import apply_topk_weights_cpu

topk_weights, topk_ids, _ = topk_output
x, topk_weights = apply_topk_weights_cpu(
apply_router_weight_on_input, topk_weights, x
)
@@ -529,8 +510,7 @@ class W8A8Int8MoEMethod(FusedMoEMethodBase):
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
topk_output=topk_output,
inplace=inplace,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
@@ -907,7 +887,7 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: List[int],
intermediate_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
@@ -984,52 +964,11 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
self,
layer,
x,
router_logits,
top_k,
renormalize,
use_grouped_topk,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
routed_scaling_factor,
topk_output: TopKOutput,
**kwargs,
) -> torch.Tensor:
from sglang.srt.layers.moe.topk import select_experts

global_num_experts = router_logits.shape[-1]
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
if global_num_experts == 256:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k,
bias=correction_bias,
k_group=topk_group,
group_count=num_expert_group,
group_select_mode=1,
renorm=0,
norm_type=1,
routed_scaling_factor=1,
eps=float(1e-20),
)
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
routed_scaling_factor=routed_scaling_factor,
)
topk_weights, topk_ids, _ = topk_output
topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype)
return npu_fused_experts(
@@ -1040,5 +979,5 @@ class NPU_W8A8MoEMethod(FusedMoEMethodBase):
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
top_k=topk_ids.shape[1],
)

@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
@@ -109,7 +110,10 @@ class DeepseekMoE(nn.Module):
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {self.n_routed_experts}."
)

self.topk = TopK(
top_k=self.top_k,
renormalize=config.norm_topk_prob,
)
self.experts = nn.ModuleList(
[
DeepseekMLP(
@@ -170,13 +174,12 @@ class DeepseekMoE(nn.Module):
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = fused_moe.fused_moe(
hidden_states,
self.w1,
self.w2,
router_logits,
self.top_k,
renormalize=self.config.norm_topk_prob,
w1=self.w1,
w2=self.w2,
topk_output=topk_output,
inplace=True,
)

@@ -58,7 +58,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_kernel import (
@@ -303,6 +303,17 @@ class DeepseekV2MoE(nn.Module):
config=config, prefix=add_prefix("gate", prefix), is_nextn=is_nextn
)

self.topk = TopK(
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
renormalize=config.norm_topk_prob,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
)

self.experts = get_moe_impl_class()(
num_experts=config.n_routed_experts
+ self.num_fused_shared_experts
@@ -311,13 +322,7 @@ class DeepseekV2MoE(nn.Module):
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
layer_id=self.layer_id,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
use_grouped_topk=True,
num_expert_group=config.n_group,
num_fused_shared_experts=self.num_fused_shared_experts,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
prefix=add_prefix("experts", prefix),
**(
@@ -451,8 +456,9 @@ class DeepseekV2MoE(nn.Module):
with torch.cuda.stream(self.alt_stream):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
hidden_states=hidden_states, topk_output=topk_output
)
if not _is_cuda:
final_hidden_states *= self.routed_scaling_factor
@@ -473,8 +479,9 @@ class DeepseekV2MoE(nn.Module):
shared_output = self._forward_shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
hidden_states=hidden_states, topk_output=topk_output
)
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
@@ -490,8 +497,9 @@ class DeepseekV2MoE(nn.Module):
) -> torch.Tensor:
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
topk_output = self.topk(hidden_states, router_logits)
fused_experts_out = self.experts(
hidden_states=hidden_states, router_logits=router_logits
hidden_states=hidden_states, topk_output=topk_output
)

assert use_intel_amx_backend(
@@ -549,17 +557,9 @@ class DeepseekV2MoE(nn.Module):
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states)
shared_output = self._forward_shared_experts(hidden_states)
topk_weights, topk_idx = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
topk_weights, topk_idx, _ = self.topk(
hidden_states,
router_logits,
num_token_non_padded=forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
@@ -649,17 +649,9 @@ class DeepseekV2MoE(nn.Module):
with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id
):
state.topk_weights_local, state.topk_idx_local = select_experts(
state.topk_weights_local, state.topk_idx_local, _ = self.topk(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=True,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
num_fused_shared_experts=self.num_fused_shared_experts,
correction_bias=self.correction_bias,
routed_scaling_factor=self.routed_scaling_factor,
num_token_non_padded=state.forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,

@@ -15,6 +15,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
@@ -60,6 +61,11 @@ class GraniteMoeMoE(nn.Module):
prefix=f"{prefix}.gate",
)

self.topk = TopK(
top_k=top_k,
renormalize=True,
)

self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
@@ -67,7 +73,6 @@ class GraniteMoeMoE(nn.Module):
intermediate_size=intermediate_size,
params_dtype=params_dtype,
reduce_results=True,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=f"{prefix}.experts",
@@ -78,7 +83,8 @@ class GraniteMoeMoE(nn.Module):
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
return final_hidden_states.view(orig_shape)

@@ -45,6 +45,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.router import fused_moe_router_shim
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
@@ -108,6 +109,12 @@ class Grok1MoE(nn.Module):
fused_moe_router_shim, self.router_logit_softcapping
)

self.topk = TopK(
top_k=top_k,
renormalize=False,
custom_routing_function=custom_routing_function,
)

kwargs = {}
if global_server_args_dict["enable_ep_moe"]:
MoEImpl = EPMoE
@@ -124,17 +131,16 @@ class Grok1MoE(nn.Module):
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
custom_routing_function=custom_routing_function,
activation="gelu",
**kwargs,
)

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
# need to assert self.gate.quant_method is unquantized
return self.experts(hidden_states, self.gate.weight)
topk_output = self.topk(hidden_states, self.gate.weight)
return self.experts(hidden_states, topk_output)


class Grok1Attention(nn.Module):

@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
@@ -152,13 +153,16 @@ class HunYuanSparseMoeBlock(nn.Module):
else config.moe_intermediate_size[layer_id]
)

self.topk = TopK(
top_k=top_k,
renormalize=True if top_k > 1 else False,
)

self.experts = FusedMoE(
num_experts=config.num_experts,
top_k=top_k,
hidden_size=config.hidden_size,
intermediate_size=intermediate_size,
reduce_results=False,
renormalize=True if top_k > 1 else False,
quant_config=quant_config,
)

@@ -195,9 +199,8 @@ class HunYuanSparseMoeBlock(nn.Module):

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:

@@ -40,6 +40,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
@@ -103,14 +104,17 @@ class Llama4MoE(nn.Module):
prefix=add_prefix("router", prefix),
)

self.topk = TopK(
top_k=self.top_k,
renormalize=False,
custom_routing_function=Llama4MoE.custom_routing_function,
)

self.experts = FusedMoE(
num_experts=config.num_local_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
custom_routing_function=Llama4MoE.custom_routing_function,
intermediate_size=intermediate_size_moe,
reduce_results=False,
renormalize=False,
quant_config=quant_config,
apply_router_weight_on_input=True,
prefix=add_prefix("experts", prefix),
@@ -147,10 +151,8 @@ class Llama4MoE(nn.Module):
# router_scores: [num_tokens, num_experts]
router_logits, _ = self.router(hidden_states)
shared_out = self.shared_expert(hidden_states)
routed_out = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
topk_output = self.topk(hidden_states, router_logits)
routed_out = self.experts(hidden_states, topk_output)
return shared_out, routed_out

def _forward_core_shared_routed_overlap(self, hidden_states):
@@ -163,10 +165,8 @@ class Llama4MoE(nn.Module):
with self.device_module.stream(alt_stream):
# router_scores: [num_tokens, num_experts]
router_logits, _ = self.router(hidden_states)
routed_out = self.experts(
hidden_states=hidden_states,
router_logits=router_logits,
)
topk_output = self.topk(hidden_states, router_logits)
routed_out = self.experts(hidden_states, topk_output)
self.device_module.current_stream().wait_stream(alt_stream)

return shared_out, routed_out

@@ -37,6 +37,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
@@ -86,6 +87,12 @@ class MixtralMoE(nn.Module):
quant_config=None,
prefix=add_prefix("gate", prefix),
)

self.topk = TopK(
top_k=top_k,
renormalize=True,
)

MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
self.experts = MoEImpl(
num_experts=num_experts,
@@ -93,7 +100,6 @@ class MixtralMoE(nn.Module):
hidden_size=hidden_size,
intermediate_size=intermediate_size,
params_dtype=params_dtype,
renormalize=True,
quant_config=quant_config,
tp_size=tp_size,
prefix=add_prefix("experts", prefix),
@@ -105,7 +111,8 @@ class MixtralMoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(orig_shape)

@@ -32,6 +32,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
@@ -76,13 +77,16 @@ class OlmoeMoE(nn.Module):
prefix=add_prefix("gate", prefix),
)

self.topk = TopK(
top_k=top_k,
renormalize=False,
)

self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=True,
renormalize=False,
quant_config=quant_config,
tp_size=tp_size,
prefix=add_prefix("experts", prefix),
@@ -94,9 +98,8 @@ class OlmoeMoE(nn.Module):
hidden_states = hidden_states.view(-1, self.hidden_size)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
return final_hidden_states.view(orig_shape)

@@ -13,6 +13,7 @@ from sglang.srt.layers.linear import (
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
@@ -200,15 +201,19 @@ class PhiMoE(nn.Module):
quant_config=None,
)

self.topk = TopK(
top_k=top_k,
renormalize=False,
custom_routing_function=phimoe_routing_function,
)

self.experts = FusedMoE(
num_experts=num_experts,
top_k=top_k,
hidden_size=hidden_size,
intermediate_size=intermediate_size,
reduce_results=True,
renormalize=False,
quant_config=quant_config,
custom_routing_function=phimoe_routing_function,
prefix=add_prefix("experts", prefix),
)

@@ -219,7 +224,8 @@ class PhiMoE(nn.Module):
orig_shape = hidden_states.shape
hidden_states = hidden_states.view(-1, self.hidden_size)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, router_logits)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
return final_hidden_states.view(orig_shape)

@@ -61,6 +61,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import EPMoE, get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
@@ -134,13 +135,17 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
f"the number of experts {config.num_experts}."
)

self.topk = TopK(
top_k=config.num_experts_per_tok,
renormalize=config.norm_topk_prob,
)

self.experts = get_moe_impl_class()(
layer_id=self.layer_id,
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
num_experts=config.num_experts,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
# Additional args for FusedMoE
@@ -189,9 +194,8 @@ class Qwen2MoeSparseMoeBlock(nn.Module):

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

@@ -56,8 +56,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
@@ -102,6 +101,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
f"the number of experts {config.num_experts}."
)

self.topk = TopK(
top_k=config.num_experts_per_tok,
renormalize=config.norm_topk_prob,
use_grouped_topk=False,
)

self.experts = get_moe_impl_class()(
num_experts=config.num_experts
+ global_server_args_dict["ep_num_redundant_experts"],
@@ -109,7 +114,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
layer_id=layer_id,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
prefix=add_prefix("experts", prefix),
**(
@@ -143,7 +147,6 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
config.num_experts + global_server_args_dict["ep_num_redundant_experts"]
)
self.top_k = config.num_experts_per_tok
self.renormalize = config.norm_topk_prob

self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
group=parallel_state.get_tp_group().device_group,
@@ -180,9 +183,8 @@ class Qwen3MoeSparseMoeBlock(nn.Module):

# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

@@ -195,13 +197,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
if is_non_idle_and_non_empty(forward_mode, hidden_states):
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)

topk_weights, topk_idx = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=self.renormalize,
topk_weights, topk_idx, _ = self.topk(
hidden_states,
router_logits,
num_token_non_padded=forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,
@@ -267,12 +265,9 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
with get_global_expert_distribution_recorder().with_current_layer(
self.layer_id
):
state.topk_weights_local, state.topk_idx_local = select_experts(
state.topk_weights_local, state.topk_idx_local, _ = self.topk(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=False,
renormalize=self.renormalize,
num_token_non_padded=state.forward_batch.num_token_non_padded,
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
layer_id=self.layer_id,

@@ -6,6 +6,7 @@ import torch

from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.fp8_kernel import (
per_tensor_quant_mla_fp8,
per_token_group_quant_fp8,
@@ -497,13 +498,17 @@ class TestW8A8BlockFP8FusedMoE(CustomTestCase):
score = torch.randn((M, E), dtype=dtype)

with torch.inference_mode():
topk_output = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
renormalize=False,
)
out = fused_moe(
a,
w1,
w2,
score,
topk,
renormalize=False,
topk_output,
use_fp8_w8a8=True,
w1_scale=w1_s,
w2_scale=w2_s,

@@ -40,7 +40,7 @@ def ep_moe(
block_shape: Optional[List[int]] = None,
):
use_blockwise_fp8 = block_shape is not None
topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, _ = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=top_k,

@@ -100,12 +100,10 @@ def test_cutlass_w4a8_moe(M, N, K, E, ep_size, topk, group_size, dtype):
s_strides2 = c_strides2

score = torch.randn((M, E), dtype=dtype, device=device)
topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, _ = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
use_grouped_topk=False,
renormalize=False,
)
expert_map = torch.arange(E, dtype=torch.int32, device=device)
expert_map[local_e:] = E

@@ -159,12 +159,10 @@ def test_cutlass_fp4_moe_no_graph(

score = torch.randn((m, e), device="cuda", dtype=dtype)

topk_weights, topk_ids = select_experts(
topk_weights, topk_ids, _ = select_experts(
hidden_states=a,
router_logits=score,
top_k=topk,
use_grouped_topk=False,
renormalize=False,
)

a1_gs = torch.ones((e,), device="cuda", dtype=torch.float32)