Reorg moe code (#2563)
This commit is contained in:
@@ -5,7 +5,9 @@ import triton
|
|||||||
from torch.nn import functional as F
|
from torch.nn import functional as F
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_triton
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||||
|
fused_moe as fused_moe_triton,
|
||||||
|
)
|
||||||
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
|
from sglang.srt.model_executor.cuda_graph_runner import set_torch_compile_config
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -5,7 +5,9 @@ import triton
|
|||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
|
from vllm.model_executor.layers.fused_moe.fused_moe import fused_moe as fused_moe_vllm
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe as fused_moe_sglang
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||||
|
fused_moe as fused_moe_sglang,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_model_config(model_name: str, tp_size: int):
|
def get_model_config(model_name: str, tp_size: int):
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import triton
|
|||||||
from ray.experimental.tqdm_ray import tqdm
|
from ray.experimental.tqdm_ray import tqdm
|
||||||
from transformers import AutoConfig
|
from transformers import AutoConfig
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import (
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||||
fused_moe,
|
fused_moe,
|
||||||
get_config_dtype_str,
|
get_config_dtype_str,
|
||||||
get_config_file_name,
|
get_config_file_name,
|
||||||
@@ -97,7 +97,7 @@ def benchmark_config(
|
|||||||
input_gating.copy_(gating_output[i])
|
input_gating.copy_(gating_output[i])
|
||||||
|
|
||||||
def run():
|
def run():
|
||||||
from sglang.srt.layers.fused_moe_triton import override_config
|
from sglang.srt.layers.moe.fused_moe_triton import override_config
|
||||||
|
|
||||||
with override_config(config):
|
with override_config(config):
|
||||||
fused_moe(
|
fused_moe(
|
||||||
|
|||||||
@@ -1,133 +0,0 @@
|
|||||||
"""
|
|
||||||
Torch-native implementation for FusedMoE. This is used for torch.compile.
|
|
||||||
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
|
|
||||||
"""
|
|
||||||
|
|
||||||
from typing import Callable, Optional
|
|
||||||
|
|
||||||
import torch
|
|
||||||
from torch.nn import functional as F
|
|
||||||
|
|
||||||
|
|
||||||
def fused_topk_native(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
gating_output: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
renormalize: bool,
|
|
||||||
):
|
|
||||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
|
||||||
M, _ = hidden_states.shape
|
|
||||||
topk_weights = torch.empty(
|
|
||||||
M, topk, dtype=torch.float32, device=hidden_states.device
|
|
||||||
)
|
|
||||||
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
|
||||||
topk_weights = F.softmax(gating_output.float(), dim=-1)
|
|
||||||
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
|
|
||||||
if renormalize:
|
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
|
||||||
return topk_weights, topk_ids
|
|
||||||
|
|
||||||
|
|
||||||
# This is used by the Deepseek-V2 model
|
|
||||||
def grouped_topk(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
gating_output: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
renormalize: bool,
|
|
||||||
num_expert_group: int = 0,
|
|
||||||
topk_group: int = 0,
|
|
||||||
):
|
|
||||||
|
|
||||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
|
||||||
|
|
||||||
scores = torch.softmax(gating_output, dim=-1)
|
|
||||||
num_token = scores.shape[0]
|
|
||||||
group_scores = (
|
|
||||||
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
|
|
||||||
) # [n, n_group]
|
|
||||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
|
|
||||||
1
|
|
||||||
] # [n, top_k_group]
|
|
||||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
|
||||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
|
||||||
score_mask = (
|
|
||||||
group_mask.unsqueeze(-1)
|
|
||||||
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
|
|
||||||
.reshape(num_token, -1)
|
|
||||||
) # [n, e]
|
|
||||||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
|
||||||
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
|
||||||
|
|
||||||
if renormalize:
|
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
|
||||||
return topk_weights, topk_ids
|
|
||||||
|
|
||||||
|
|
||||||
def select_experts_native(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
top_k: int,
|
|
||||||
use_grouped_topk: bool,
|
|
||||||
renormalize: bool,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
):
|
|
||||||
# DeekSeekv2 uses grouped_top_k
|
|
||||||
if use_grouped_topk:
|
|
||||||
assert topk_group is not None
|
|
||||||
assert num_expert_group is not None
|
|
||||||
topk_weights, topk_ids = grouped_topk(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
gating_output=router_logits,
|
|
||||||
topk=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
topk_group=topk_group,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
topk_weights, topk_ids = fused_topk_native(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
gating_output=router_logits,
|
|
||||||
topk=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
)
|
|
||||||
return topk_weights, topk_ids
|
|
||||||
|
|
||||||
|
|
||||||
def fused_moe_forward_native(
|
|
||||||
layer: torch.nn.Module,
|
|
||||||
x: torch.Tensor,
|
|
||||||
use_grouped_topk: bool,
|
|
||||||
top_k: int,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
renormalize: bool,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
) -> torch.Tensor:
|
|
||||||
|
|
||||||
if use_grouped_topk:
|
|
||||||
assert num_expert_group is not None and topk_group is not None
|
|
||||||
topk_weights, topk_ids = grouped_topk(
|
|
||||||
x,
|
|
||||||
router_logits,
|
|
||||||
top_k,
|
|
||||||
renormalize,
|
|
||||||
num_expert_group,
|
|
||||||
topk_group,
|
|
||||||
)
|
|
||||||
elif custom_routing_function is None:
|
|
||||||
topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
|
|
||||||
else:
|
|
||||||
topk_weights, topk_ids = custom_routing_function(
|
|
||||||
x, router_logits, top_k, renormalize
|
|
||||||
)
|
|
||||||
|
|
||||||
w13_weights = layer.w13_weight[topk_ids]
|
|
||||||
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
|
|
||||||
w2_weights = layer.w2_weight[topk_ids]
|
|
||||||
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
|
|
||||||
x1 = F.silu(x1)
|
|
||||||
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
|
||||||
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
|
||||||
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|
|
||||||
@@ -12,15 +12,15 @@ from vllm.model_executor.custom_op import CustomOp
|
|||||||
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
|
||||||
|
|
||||||
from sglang.srt.layers.custom_op_util import register_custom_op
|
from sglang.srt.layers.custom_op_util import register_custom_op
|
||||||
from sglang.srt.layers.ep_moe.kernels import (
|
from sglang.srt.layers.moe.ep_moe.kernels import (
|
||||||
grouped_gemm_triton,
|
grouped_gemm_triton,
|
||||||
post_reorder_triton_kernel,
|
post_reorder_triton_kernel,
|
||||||
pre_reorder_triton_kernel,
|
pre_reorder_triton_kernel,
|
||||||
run_moe_ep_preproess,
|
run_moe_ep_preproess,
|
||||||
silu_and_mul_triton_kernel,
|
silu_and_mul_triton_kernel,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
|
||||||
from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
@@ -113,6 +113,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
quant_config: Optional[QuantizationConfig] = None,
|
quant_config: Optional[QuantizationConfig] = None,
|
||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -138,6 +139,7 @@ class EPMoE(torch.nn.Module):
|
|||||||
assert num_expert_group is not None and topk_group is not None
|
assert num_expert_group is not None and topk_group is not None
|
||||||
self.num_expert_group = num_expert_group
|
self.num_expert_group = num_expert_group
|
||||||
self.topk_group = topk_group
|
self.topk_group = topk_group
|
||||||
|
self.correction_bias = correction_bias
|
||||||
|
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
|
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
|
||||||
@@ -170,13 +172,15 @@ class EPMoE(torch.nn.Module):
|
|||||||
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
|
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
|
||||||
)
|
)
|
||||||
|
|
||||||
topk_weights, topk_ids = self.select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits,
|
router_logits=router_logits,
|
||||||
self.top_k,
|
top_k=self.top_k,
|
||||||
self.renormalize,
|
use_grouped_topk=self.use_grouped_topk,
|
||||||
self.topk_group,
|
renormalize=self.renormalize,
|
||||||
self.num_expert_group,
|
topk_group=self.topk_group,
|
||||||
|
num_expert_group=self.num_expert_group,
|
||||||
|
correction_bias=self.correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
|
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
|
||||||
@@ -297,35 +301,6 @@ class EPMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def select_experts(
|
|
||||||
self,
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
top_k: int,
|
|
||||||
renormalize: bool,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
):
|
|
||||||
if self.use_grouped_topk:
|
|
||||||
assert topk_group is not None
|
|
||||||
assert num_expert_group is not None
|
|
||||||
topk_weights, topk_ids = grouped_topk(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
gating_output=router_logits,
|
|
||||||
topk=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
topk_group=topk_group,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
topk_weights, topk_ids = fused_topk(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
gating_output=router_logits,
|
|
||||||
topk=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
)
|
|
||||||
return topk_weights, topk_ids.to(torch.int32)
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def make_expert_params_mapping(
|
def make_expert_params_mapping(
|
||||||
cls,
|
cls,
|
||||||
46
python/sglang/srt/layers/moe/fused_moe_native.py
Normal file
46
python/sglang/srt/layers/moe/fused_moe_native.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""
|
||||||
|
Torch-native implementation for FusedMoE. This is used for torch.compile.
|
||||||
|
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch.nn import functional as F
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
|
|
||||||
|
|
||||||
|
def fused_moe_forward_native(
|
||||||
|
layer: torch.nn.Module,
|
||||||
|
x: torch.Tensor,
|
||||||
|
use_grouped_topk: bool,
|
||||||
|
top_k: int,
|
||||||
|
router_logits: torch.Tensor,
|
||||||
|
renormalize: bool,
|
||||||
|
topk_group: Optional[int] = None,
|
||||||
|
num_expert_group: Optional[int] = None,
|
||||||
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
topk_weights, topk_ids = select_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
router_logits=router_logits,
|
||||||
|
use_grouped_topk=use_grouped_topk,
|
||||||
|
top_k=top_k,
|
||||||
|
renormalize=renormalize,
|
||||||
|
topk_group=topk_group,
|
||||||
|
num_expert_group=num_expert_group,
|
||||||
|
custom_routing_function=custom_routing_function,
|
||||||
|
correction_bias=correction_bias,
|
||||||
|
torch_native=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
w13_weights = layer.w13_weight[topk_ids]
|
||||||
|
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
|
||||||
|
w2_weights = layer.w2_weight[topk_ids]
|
||||||
|
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
|
||||||
|
x1 = F.silu(x1)
|
||||||
|
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
|
||||||
|
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
|
||||||
|
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
|
||||||
@@ -1,14 +1,12 @@
|
|||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import sglang.srt.layers.fused_moe_triton.fused_moe # noqa
|
import sglang.srt.layers.moe.fused_moe_triton.fused_moe # noqa
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import (
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
|
||||||
fused_experts,
|
fused_experts,
|
||||||
fused_topk,
|
|
||||||
get_config_file_name,
|
get_config_file_name,
|
||||||
grouped_topk,
|
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.fused_moe_triton.layer import (
|
from sglang.srt.layers.moe.fused_moe_triton.layer import (
|
||||||
FusedMoE,
|
FusedMoE,
|
||||||
FusedMoEMethodBase,
|
FusedMoEMethodBase,
|
||||||
FusedMoeWeightScaleSupported,
|
FusedMoeWeightScaleSupported,
|
||||||
@@ -37,8 +35,6 @@ __all__ = [
|
|||||||
"override_config",
|
"override_config",
|
||||||
"get_config",
|
"get_config",
|
||||||
"fused_moe",
|
"fused_moe",
|
||||||
"fused_topk",
|
|
||||||
"fused_experts",
|
"fused_experts",
|
||||||
"get_config_file_name",
|
"get_config_file_name",
|
||||||
"grouped_topk",
|
|
||||||
]
|
]
|
||||||
@@ -13,6 +13,7 @@ import triton
|
|||||||
import triton.language as tl
|
import triton.language as tl
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.utils import direct_register_custom_op, get_device_name
|
from sglang.srt.utils import direct_register_custom_op, get_device_name
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -415,7 +416,7 @@ def try_get_optimal_moe_config(
|
|||||||
M: int,
|
M: int,
|
||||||
is_marlin: bool = False,
|
is_marlin: bool = False,
|
||||||
):
|
):
|
||||||
from sglang.srt.layers.fused_moe_triton import get_config
|
from sglang.srt.layers.moe.fused_moe_triton import get_config
|
||||||
|
|
||||||
override_config = get_config()
|
override_config = get_config()
|
||||||
if override_config:
|
if override_config:
|
||||||
@@ -435,74 +436,6 @@ def try_get_optimal_moe_config(
|
|||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
def fused_topk(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
gating_output: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
renormalize: bool,
|
|
||||||
):
|
|
||||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
|
||||||
|
|
||||||
M, _ = hidden_states.shape
|
|
||||||
|
|
||||||
topk_weights = torch.empty(
|
|
||||||
M, topk, dtype=torch.float32, device=hidden_states.device
|
|
||||||
)
|
|
||||||
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
|
||||||
token_expert_indicies = torch.empty(
|
|
||||||
M, topk, dtype=torch.int32, device=hidden_states.device
|
|
||||||
)
|
|
||||||
|
|
||||||
ops.topk_softmax(
|
|
||||||
topk_weights,
|
|
||||||
topk_ids,
|
|
||||||
token_expert_indicies,
|
|
||||||
gating_output.float(), # TODO(woosuk): Optimize this.
|
|
||||||
)
|
|
||||||
del token_expert_indicies # Not used. Will be used in the future.
|
|
||||||
|
|
||||||
if renormalize:
|
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
|
||||||
|
|
||||||
|
|
||||||
# This is used by the Deepseek-V2 model
|
|
||||||
def grouped_topk(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
gating_output: torch.Tensor,
|
|
||||||
topk: int,
|
|
||||||
renormalize: bool,
|
|
||||||
num_expert_group: int = 0,
|
|
||||||
topk_group: int = 0,
|
|
||||||
):
|
|
||||||
|
|
||||||
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
|
||||||
|
|
||||||
scores = torch.softmax(gating_output, dim=-1)
|
|
||||||
num_token = scores.shape[0]
|
|
||||||
group_scores = (
|
|
||||||
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
|
|
||||||
) # [n, n_group]
|
|
||||||
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
|
|
||||||
1
|
|
||||||
] # [n, top_k_group]
|
|
||||||
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
|
||||||
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
|
||||||
score_mask = (
|
|
||||||
group_mask.unsqueeze(-1)
|
|
||||||
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
|
|
||||||
.reshape(num_token, -1)
|
|
||||||
) # [n, e]
|
|
||||||
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
|
||||||
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
|
||||||
|
|
||||||
if renormalize:
|
|
||||||
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
|
||||||
|
|
||||||
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
|
||||||
|
|
||||||
|
|
||||||
def get_config_dtype_str(
|
def get_config_dtype_str(
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
use_int8_w8a16: Optional[bool] = False,
|
use_int8_w8a16: Optional[bool] = False,
|
||||||
@@ -869,23 +802,15 @@ def fused_moe(
|
|||||||
# Check constraints.
|
# Check constraints.
|
||||||
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
|
||||||
|
|
||||||
if use_grouped_topk:
|
topk_weights, topk_ids = select_experts(
|
||||||
assert num_expert_group is not None and topk_group is not None
|
hidden_states=hidden_states,
|
||||||
topk_weights, topk_ids = grouped_topk(
|
router_logits=gating_output,
|
||||||
hidden_states,
|
use_grouped_topk=use_grouped_topk,
|
||||||
gating_output,
|
top_k=topk,
|
||||||
topk,
|
renormalize=renormalize,
|
||||||
renormalize,
|
topk_group=topk_group,
|
||||||
num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
topk_group,
|
custom_routing_function=custom_routing_function,
|
||||||
)
|
|
||||||
elif custom_routing_function is None:
|
|
||||||
topk_weights, topk_ids = fused_topk(
|
|
||||||
hidden_states, gating_output, topk, renormalize
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
topk_weights, topk_ids = custom_routing_function(
|
|
||||||
hidden_states, gating_output, topk, renormalize
|
|
||||||
)
|
)
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
@@ -13,6 +13,7 @@ from vllm.distributed import (
|
|||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
from sglang.srt.layers.custom_op_util import register_custom_op
|
from sglang.srt.layers.custom_op_util import register_custom_op
|
||||||
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
@@ -20,7 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
|
|||||||
from sglang.srt.utils import set_weight_attrs
|
from sglang.srt.utils import set_weight_attrs
|
||||||
|
|
||||||
if torch.cuda.is_available():
|
if torch.cuda.is_available():
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
else:
|
else:
|
||||||
fused_experts = None # type: ignore
|
fused_experts = None # type: ignore
|
||||||
|
|
||||||
@@ -106,6 +107,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.forward(
|
return self.forward(
|
||||||
x=x,
|
x=x,
|
||||||
@@ -117,6 +119,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
|
correction_bias=correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward_cuda(
|
def forward_cuda(
|
||||||
@@ -130,8 +133,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
use_grouped_topk=use_grouped_topk,
|
use_grouped_topk=use_grouped_topk,
|
||||||
@@ -140,6 +144,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
|
|||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
|
correction_bias=correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
@@ -197,6 +202,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
tp_size: Optional[int] = None,
|
tp_size: Optional[int] = None,
|
||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@@ -217,6 +223,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
self.num_expert_group = num_expert_group
|
self.num_expert_group = num_expert_group
|
||||||
self.topk_group = topk_group
|
self.topk_group = topk_group
|
||||||
self.custom_routing_function = custom_routing_function
|
self.custom_routing_function = custom_routing_function
|
||||||
|
self.correction_bias = correction_bias
|
||||||
|
|
||||||
if quant_config is None:
|
if quant_config is None:
|
||||||
self.quant_method: Optional[QuantizeMethodBase] = (
|
self.quant_method: Optional[QuantizeMethodBase] = (
|
||||||
@@ -503,51 +510,6 @@ class FusedMoE(torch.nn.Module):
|
|||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def select_experts(
|
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
router_logits: torch.Tensor,
|
|
||||||
top_k: int,
|
|
||||||
use_grouped_topk: bool,
|
|
||||||
renormalize: bool,
|
|
||||||
topk_group: Optional[int] = None,
|
|
||||||
num_expert_group: Optional[int] = None,
|
|
||||||
custom_routing_function: Optional[Callable] = None,
|
|
||||||
):
|
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import (
|
|
||||||
fused_topk,
|
|
||||||
grouped_topk,
|
|
||||||
)
|
|
||||||
|
|
||||||
# DeekSeekv2 uses grouped_top_k
|
|
||||||
if use_grouped_topk:
|
|
||||||
assert topk_group is not None
|
|
||||||
assert num_expert_group is not None
|
|
||||||
topk_weights, topk_ids = grouped_topk(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
gating_output=router_logits,
|
|
||||||
topk=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
num_expert_group=num_expert_group,
|
|
||||||
topk_group=topk_group,
|
|
||||||
)
|
|
||||||
elif custom_routing_function is None:
|
|
||||||
topk_weights, topk_ids = fused_topk(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
gating_output=router_logits,
|
|
||||||
topk=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
topk_weights, topk_ids = custom_routing_function(
|
|
||||||
hidden_states=hidden_states,
|
|
||||||
gating_output=router_logits,
|
|
||||||
topk=top_k,
|
|
||||||
renormalize=renormalize,
|
|
||||||
)
|
|
||||||
|
|
||||||
return topk_weights, topk_ids
|
|
||||||
|
|
||||||
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
@@ -562,6 +524,7 @@ class FusedMoE(torch.nn.Module):
|
|||||||
topk_group=self.topk_group,
|
topk_group=self.topk_group,
|
||||||
num_expert_group=self.num_expert_group,
|
num_expert_group=self.num_expert_group,
|
||||||
custom_routing_function=self.custom_routing_function,
|
custom_routing_function=self.custom_routing_function,
|
||||||
|
correction_bias=self.correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.reduce_results and self.tp_size > 1:
|
if self.reduce_results and self.tp_size > 1:
|
||||||
191
python/sglang/srt/layers/moe/topk.py
Normal file
191
python/sglang/srt/layers/moe/topk.py
Normal file
@@ -0,0 +1,191 @@
|
|||||||
|
from typing import Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
|
||||||
|
def fused_topk_native(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
renormalize: bool,
|
||||||
|
):
|
||||||
|
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||||
|
M, _ = hidden_states.shape
|
||||||
|
topk_weights = torch.empty(
|
||||||
|
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||||
|
)
|
||||||
|
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
||||||
|
topk_weights = F.softmax(gating_output.float(), dim=-1)
|
||||||
|
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
|
||||||
|
if renormalize:
|
||||||
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
|
def fused_topk(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
renormalize: bool,
|
||||||
|
):
|
||||||
|
from vllm import _custom_ops as ops
|
||||||
|
|
||||||
|
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||||
|
|
||||||
|
M, _ = hidden_states.shape
|
||||||
|
|
||||||
|
topk_weights = torch.empty(
|
||||||
|
M, topk, dtype=torch.float32, device=hidden_states.device
|
||||||
|
)
|
||||||
|
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
|
||||||
|
token_expert_indicies = torch.empty(
|
||||||
|
M, topk, dtype=torch.int32, device=hidden_states.device
|
||||||
|
)
|
||||||
|
|
||||||
|
ops.topk_softmax(
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
token_expert_indicies,
|
||||||
|
gating_output.float(),
|
||||||
|
)
|
||||||
|
del token_expert_indicies
|
||||||
|
|
||||||
|
if renormalize:
|
||||||
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
return topk_weights, topk_ids
|
||||||
|
|
||||||
|
|
||||||
|
# This is used by the Deepseek-V2 model
|
||||||
|
def grouped_topk(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
gating_output: torch.Tensor,
|
||||||
|
topk: int,
|
||||||
|
renormalize: bool,
|
||||||
|
num_expert_group: int = 0,
|
||||||
|
topk_group: int = 0,
|
||||||
|
):
|
||||||
|
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
|
||||||
|
|
||||||
|
scores = torch.softmax(gating_output, dim=-1)
|
||||||
|
num_token = scores.shape[0]
|
||||||
|
group_scores = (
|
||||||
|
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
|
||||||
|
) # [n, n_group]
|
||||||
|
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
|
||||||
|
1
|
||||||
|
] # [n, top_k_group]
|
||||||
|
group_mask = torch.zeros_like(group_scores) # [n, n_group]
|
||||||
|
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
|
||||||
|
score_mask = (
|
||||||
|
group_mask.unsqueeze(-1)
|
||||||
|
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
|
||||||
|
.reshape(num_token, -1)
|
||||||
|
) # [n, e]
|
||||||
|
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
|
||||||
|
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
|
||||||
|
|
||||||
|
if renormalize:
|
||||||
|
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
|
||||||
|
|
||||||
|
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||||
|
|
||||||
|
|
||||||
|
def biased_grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    correction_bias: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):
    """Group-limited top-k routing with a per-expert selection bias.

    Gate logits are passed through a sigmoid. ``correction_bias`` is added to
    those scores only for the purpose of *choosing* groups and experts; the
    returned routing weights are the unbiased sigmoid scores of the chosen
    experts.

    Experts are split into ``num_expert_group`` contiguous, equally sized
    groups. Each group is scored by the sum of its two highest biased scores,
    the ``topk_group`` best groups are kept per token, and ``topk`` experts are
    then picked among the kept groups.

    Args:
        hidden_states: Only its first dimension (token count) is inspected.
        gating_output: Router logits, one row per token.
        correction_bias: Per-expert bias, broadcast over tokens.
        topk: Number of experts to return per token.
        renormalize: If true, rescale each token's weights to sum to 1.
        num_expert_group: Number of expert groups.
        topk_group: Number of groups kept per token.

    Returns:
        ``(topk_weights, topk_ids)`` as float32 and int32 tensors. The order
        within a row is unspecified because top-k is taken with
        ``sorted=False``.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    scores = gating_output.sigmoid()
    num_token = scores.shape[0]
    scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)

    # Group score = sum of the two best biased expert scores. [n, n_group]
    group_scores = (
        scores_for_choice.view(num_token, num_expert_group, -1)
        .topk(2, dim=-1)[0]
        .sum(dim=-1)
    )
    group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
        1
    ]  # [n, top_k_group]

    group_mask = torch.zeros_like(group_scores)  # [n, n_group]
    group_mask.scatter_(1, group_idx, 1)  # [n, n_group]
    score_mask = (
        group_mask.unsqueeze(-1)
        .expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
        .reshape(num_token, -1)
    )  # [n, e]

    # Biased scores can be negative, so masking dropped groups with 0.0 would
    # let their experts outrank legitimate candidates. Use -inf so an expert
    # from a dropped group is never preferred over one from a kept group. The
    # weights below are gathered from the unbiased ``scores``, so no
    # infinities reach the output.
    tmp_scores = scores_for_choice.masked_fill(
        ~score_mask.bool(), float("-inf")
    )  # [n, e]
    _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
    topk_weights = scores.gather(1, topk_ids)

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
|
||||||
|
|
||||||
|
|
||||||
|
def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    use_grouped_topk: bool,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    correction_bias: Optional[torch.Tensor] = None,
    torch_native: bool = False,
):
    """Dispatch to the routing function that matches the given options.

    Precedence, first match wins:
      1. ``use_grouped_topk``: ``grouped_topk``, or ``biased_grouped_topk``
         when ``correction_bias`` is provided. Both ``topk_group`` and
         ``num_expert_group`` must be set.
      2. ``torch_native``: ``fused_topk_native``.
      3. no ``custom_routing_function``: ``fused_topk``.
      4. otherwise: ``custom_routing_function``.

    Every routing function is called with keyword arguments
    ``hidden_states``, ``gating_output``, ``topk`` and ``renormalize``.

    Returns:
        The ``(topk_weights, topk_ids)`` pair produced by the chosen function.
    """
    if use_grouped_topk:
        # DeepSeek-V2 uses grouped top-k routing.
        assert topk_group is not None
        assert num_expert_group is not None
        grouped_kwargs = dict(
            hidden_states=hidden_states,
            gating_output=router_logits,
            topk=top_k,
            renormalize=renormalize,
            num_expert_group=num_expert_group,
            topk_group=topk_group,
        )
        if correction_bias is not None:
            topk_weights, topk_ids = biased_grouped_topk(
                correction_bias=correction_bias, **grouped_kwargs
            )
        else:
            topk_weights, topk_ids = grouped_topk(**grouped_kwargs)
        return topk_weights, topk_ids

    # Non-grouped paths share one call signature; only the callee differs.
    if torch_native:
        routing_fn = fused_topk_native
    elif custom_routing_function is None:
        routing_fn = fused_topk
    else:
        routing_fn = custom_routing_function

    topk_weights, topk_ids = routing_fn(
        hidden_states=hidden_states,
        gating_output=router_logits,
        topk=top_k,
        renormalize=renormalize,
    )
    return topk_weights, topk_ids
|
||||||
@@ -60,8 +60,8 @@ def fp8_get_quant_method(self, layer, prefix):
|
|||||||
is_layer_skipped,
|
is_layer_skipped,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
|
|
||||||
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
from sglang.srt.layers.linear import UnquantizedLinearMethod
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
|
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
@@ -80,7 +80,7 @@ def gptq_get_quant_method(self, layer, prefix):
|
|||||||
GPTQMarlinMoEMethod,
|
GPTQMarlinMoEMethod,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
return GPTQMarlinLinearMethod(self)
|
return GPTQMarlinLinearMethod(self)
|
||||||
@@ -96,7 +96,7 @@ def awq_get_quant_method(self, layer, prefix):
|
|||||||
AWQMoEMethod,
|
AWQMoEMethod,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
return AWQMarlinLinearMethod(self)
|
return AWQMarlinLinearMethod(self)
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
|
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
|
|
||||||
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
|
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
|
||||||
from sglang.srt.layers.quantization.base_config import (
|
from sglang.srt.layers.quantization.base_config import (
|
||||||
QuantizationConfig,
|
QuantizationConfig,
|
||||||
QuantizeMethodBase,
|
QuantizeMethodBase,
|
||||||
@@ -98,7 +98,7 @@ class Fp8Config(QuantizationConfig):
|
|||||||
) -> Optional["QuantizeMethodBase"]:
|
) -> Optional["QuantizeMethodBase"]:
|
||||||
from vllm.attention.layer import Attention # Avoid circular import
|
from vllm.attention.layer import Attention # Avoid circular import
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
|
|
||||||
if isinstance(layer, LinearBase):
|
if isinstance(layer, LinearBase):
|
||||||
if is_layer_skipped(prefix, self.ignored_layers):
|
if is_layer_skipped(prefix, self.ignored_layers):
|
||||||
@@ -320,7 +320,7 @@ class Fp8MoEMethod:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __new__(cls, *args, **kwargs):
|
def __new__(cls, *args, **kwargs):
|
||||||
from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
|
||||||
|
|
||||||
if not hasattr(cls, "_initialized"):
|
if not hasattr(cls, "_initialized"):
|
||||||
original_init = cls.__init__
|
original_init = cls.__init__
|
||||||
@@ -349,7 +349,7 @@ class Fp8MoEMethod:
|
|||||||
params_dtype: torch.dtype,
|
params_dtype: torch.dtype,
|
||||||
**extra_weight_attrs,
|
**extra_weight_attrs,
|
||||||
):
|
):
|
||||||
from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
|
||||||
|
|
||||||
if self.quant_config.is_checkpoint_fp8_serialized:
|
if self.quant_config.is_checkpoint_fp8_serialized:
|
||||||
params_dtype = torch.float8_e4m3fn
|
params_dtype = torch.float8_e4m3fn
|
||||||
@@ -566,12 +566,14 @@ class Fp8MoEMethod:
|
|||||||
topk_group: Optional[int] = None,
|
topk_group: Optional[int] = None,
|
||||||
num_expert_group: Optional[int] = None,
|
num_expert_group: Optional[int] = None,
|
||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
|
correction_bias: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
|
||||||
|
from sglang.srt.layers.moe.topk import select_experts
|
||||||
|
|
||||||
# Expert selection
|
# Expert selection
|
||||||
topk_weights, topk_ids = FusedMoE.select_experts(
|
topk_weights, topk_ids = select_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
use_grouped_topk=use_grouped_topk,
|
use_grouped_topk=use_grouped_topk,
|
||||||
@@ -580,6 +582,7 @@ class Fp8MoEMethod:
|
|||||||
topk_group=topk_group,
|
topk_group=topk_group,
|
||||||
num_expert_group=num_expert_group,
|
num_expert_group=num_expert_group,
|
||||||
custom_routing_function=custom_routing_function,
|
custom_routing_function=custom_routing_function,
|
||||||
|
correction_bias=correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Expert fusion with FP8 quantization
|
# Expert fusion with FP8 quantization
|
||||||
|
|||||||
@@ -25,12 +25,12 @@ from vllm.distributed import get_tensor_model_parallel_rank
|
|||||||
from vllm.distributed.parallel_state import graph_capture
|
from vllm.distributed.parallel_state import graph_capture
|
||||||
from vllm.model_executor.custom_op import CustomOp
|
from vllm.model_executor.custom_op import CustomOp
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
|
|
||||||
from sglang.srt.layers.logits_processor import (
|
from sglang.srt.layers.logits_processor import (
|
||||||
LogitsMetadata,
|
LogitsMetadata,
|
||||||
LogitsProcessor,
|
LogitsProcessor,
|
||||||
LogitsProcessorOutput,
|
LogitsProcessorOutput,
|
||||||
)
|
)
|
||||||
|
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
|
||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
||||||
from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
|
from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather
|
||||||
|
|
||||||
|
|||||||
@@ -27,13 +27,13 @@ from vllm.distributed import (
|
|||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton import fused_moe
|
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
ReplicatedLinear,
|
ReplicatedLinear,
|
||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from vllm.distributed import (
|
|||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.fused_moe_triton import fused_moe
|
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
|
|||||||
@@ -19,6 +19,7 @@
|
|||||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.nn.functional as F
|
||||||
from torch import nn
|
from torch import nn
|
||||||
from transformers import PretrainedConfig
|
from transformers import PretrainedConfig
|
||||||
from vllm import _custom_ops as ops
|
from vllm import _custom_ops as ops
|
||||||
@@ -31,8 +32,6 @@ from vllm.distributed import (
|
|||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.ep_moe.layer import EPMoE
|
|
||||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
ColumnParallelLinear,
|
ColumnParallelLinear,
|
||||||
@@ -41,6 +40,8 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
@@ -90,6 +91,24 @@ class DeepseekV2MLP(nn.Module):
|
|||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class MoEGate(nn.Module):
    """Bias-free linear router that maps hidden states to per-expert logits.

    Holds a ``(n_routed_experts, hidden_size)`` weight. When
    ``config.topk_method`` is ``"noaux_tc"`` it also holds a per-expert
    ``e_score_correction_bias`` parameter; otherwise that attribute is
    ``None``. Parameters are allocated uninitialized with ``torch.empty``.
    """

    def __init__(self, config):
        super().__init__()
        num_experts = config.n_routed_experts
        self.weight = nn.Parameter(torch.empty((num_experts, config.hidden_size)))
        # The correction bias is not used in forward(); it is only stored here.
        self.e_score_correction_bias = (
            nn.Parameter(torch.empty(num_experts))
            if config.topk_method == "noaux_tc"
            else None
        )

    def forward(self, hidden_states):
        """Return router logits ``hidden_states @ weight.T`` with no bias."""
        return F.linear(hidden_states, self.weight, None)
|
||||||
|
|
||||||
|
|
||||||
class DeepseekV2MoE(nn.Module):
|
class DeepseekV2MoE(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -114,6 +133,8 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
"Only silu is supported for now."
|
"Only silu is supported for now."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
self.gate = MoEGate(config=config)
|
||||||
|
|
||||||
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
|
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
|
||||||
self.experts = MoEImpl(
|
self.experts = MoEImpl(
|
||||||
num_experts=config.n_routed_experts,
|
num_experts=config.n_routed_experts,
|
||||||
@@ -125,11 +146,9 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
use_grouped_topk=True,
|
use_grouped_topk=True,
|
||||||
num_expert_group=config.n_group,
|
num_expert_group=config.n_group,
|
||||||
topk_group=config.topk_group,
|
topk_group=config.topk_group,
|
||||||
|
correction_bias=self.gate.e_score_correction_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gate = ReplicatedLinear(
|
|
||||||
config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
|
|
||||||
)
|
|
||||||
if config.n_shared_experts is not None:
|
if config.n_shared_experts is not None:
|
||||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||||
self.shared_experts = DeepseekV2MLP(
|
self.shared_experts = DeepseekV2MLP(
|
||||||
@@ -146,7 +165,7 @@ class DeepseekV2MoE(nn.Module):
|
|||||||
if self.n_shared_experts is not None:
|
if self.n_shared_experts is not None:
|
||||||
shared_output = self.shared_experts(hidden_states)
|
shared_output = self.shared_experts(hidden_states)
|
||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits = self.gate(hidden_states)
|
||||||
final_hidden_states = (
|
final_hidden_states = (
|
||||||
self.experts(hidden_states=hidden_states, router_logits=router_logits)
|
self.experts(hidden_states=hidden_states, router_logits=router_logits)
|
||||||
* self.routed_scaling_factor
|
* self.routed_scaling_factor
|
||||||
@@ -439,7 +458,10 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
)
|
)
|
||||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
if rope_scaling:
|
||||||
rope_scaling["rope_type"] = "deepseek_yarn"
|
rope_scaling["rope_type"] = "deepseek_yarn"
|
||||||
|
|
||||||
self.rotary_emb = get_rope(
|
self.rotary_emb = get_rope(
|
||||||
qk_rope_head_dim,
|
qk_rope_head_dim,
|
||||||
rotary_dim=qk_rope_head_dim,
|
rotary_dim=qk_rope_head_dim,
|
||||||
@@ -454,6 +476,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|||||||
scaling_factor = rope_scaling["factor"]
|
scaling_factor = rope_scaling["factor"]
|
||||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||||
self.scaling = self.scaling * mscale * mscale
|
self.scaling = self.scaling * mscale * mscale
|
||||||
|
else:
|
||||||
|
self.rotary_emb.forward = self.rotary_emb.forward_native
|
||||||
|
|
||||||
self.attn_mqa = RadixAttention(
|
self.attn_mqa = RadixAttention(
|
||||||
self.num_local_heads,
|
self.num_local_heads,
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
|
|||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
from sglang.srt.layers.activation import GeluAndMul
|
from sglang.srt.layers.activation import GeluAndMul
|
||||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
|
|||||||
@@ -27,8 +27,6 @@ from vllm.distributed import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
from sglang.srt.layers.ep_moe.layer import EPMoE
|
|
||||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
QKVParallelLinear,
|
QKVParallelLinear,
|
||||||
@@ -36,6 +34,8 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
|
|||||||
@@ -36,9 +36,9 @@ from vllm.model_executor.layers.linear import (
|
|||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
|
|||||||
@@ -29,7 +29,6 @@ from vllm.distributed import (
|
|||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import (
|
from sglang.srt.layers.linear import (
|
||||||
MergedColumnParallelLinear,
|
MergedColumnParallelLinear,
|
||||||
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
|
|||||||
RowParallelLinear,
|
RowParallelLinear,
|
||||||
)
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
|
|||||||
@@ -33,8 +33,8 @@ from vllm.model_executor.layers.linear import (
|
|||||||
)
|
)
|
||||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||||
|
|
||||||
from sglang.srt.layers.fused_moe_triton import fused_moe
|
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
|
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
from sglang.srt.layers.radix_attention import RadixAttention
|
from sglang.srt.layers.radix_attention import RadixAttention
|
||||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ import torch
|
|||||||
from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
|
from vllm.model_executor.layers.fused_moe import fused_moe as fused_moe_vllm
|
||||||
|
|
||||||
from sglang.srt.layers.activation import SiluAndMul
|
from sglang.srt.layers.activation import SiluAndMul
|
||||||
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_moe
|
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe
|
||||||
|
|
||||||
|
|
||||||
class TestFusedMOE(unittest.TestCase):
|
class TestFusedMOE(unittest.TestCase):
|
||||||
|
|||||||
Reference in New Issue
Block a user