Reorg moe code (#2563)

This commit is contained in:
Ke Bao
2024-12-24 01:10:22 +08:00
committed by GitHub
parent 23e5e50fd5
commit e835a50021
88 changed files with 338 additions and 344 deletions

View File

@@ -1,133 +0,0 @@
"""
Torch-native implementation for FusedMoE. This is used for torch.compile.
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
"""
from typing import Callable, Optional
import torch
from torch.nn import functional as F
def fused_topk_native(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
M, _ = hidden_states.shape
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
topk_weights = F.softmax(gating_output.float(), dim=-1)
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
# This is used by the Deepseek-V2 model
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def select_experts_native(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
):
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
)
else:
topk_weights, topk_ids = fused_topk_native(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
return topk_weights, topk_ids
def fused_moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
) -> torch.Tensor:
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(
x,
router_logits,
top_k,
renormalize,
num_expert_group,
topk_group,
)
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
else:
topk_weights, topk_ids = custom_routing_function(
x, router_logits, top_k, renormalize
)
w13_weights = layer.w13_weight[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = layer.w2_weight[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))

View File

@@ -12,15 +12,15 @@ from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.ep_moe.kernels import (
from sglang.srt.layers.moe.ep_moe.kernels import (
grouped_gemm_triton,
post_reorder_triton_kernel,
pre_reorder_triton_kernel,
run_moe_ep_preproess,
silu_and_mul_triton_kernel,
)
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_topk, grouped_topk
from sglang.srt.layers.fused_moe_triton.layer import FusedMoEMethodBase
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
@@ -113,6 +113,7 @@ class EPMoE(torch.nn.Module):
quant_config: Optional[QuantizationConfig] = None,
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
):
super().__init__()
@@ -138,6 +139,7 @@ class EPMoE(torch.nn.Module):
assert num_expert_group is not None and topk_group is not None
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -170,13 +172,15 @@ class EPMoE(torch.nn.Module):
hidden_states.device, use_flashinfer=False # TODO: use flashinfer
)
topk_weights, topk_ids = self.select_experts(
hidden_states,
router_logits,
self.top_k,
self.renormalize,
self.topk_group,
self.num_expert_group,
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
)
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -297,35 +301,6 @@ class EPMoE(torch.nn.Module):
)
return output
def select_experts(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
):
if self.use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
)
else:
topk_weights, topk_ids = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
return topk_weights, topk_ids.to(torch.int32)
@classmethod
def make_expert_params_mapping(
cls,

View File

@@ -0,0 +1,46 @@
"""
Torch-native implementation for FusedMoE. This is used for torch.compile.
It is based on https://github.com/pytorch-labs/gpt-fast/blob/32971d3129541c5bfb4f715abc33d1c5f408d204/mixtral-moe/model.py#L204
"""
from typing import Callable, Optional
import torch
from torch.nn import functional as F
from sglang.srt.layers.moe.topk import select_experts
def fused_moe_forward_native(
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
)
w13_weights = layer.w13_weight[topk_ids]
w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
w2_weights = layer.w2_weight[topk_ids]
x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
x1 = F.silu(x1)
x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))

View File

@@ -1,14 +1,12 @@
from contextlib import contextmanager
from typing import Any, Dict, Optional
import sglang.srt.layers.fused_moe_triton.fused_moe # noqa
from sglang.srt.layers.fused_moe_triton.fused_moe import (
import sglang.srt.layers.moe.fused_moe_triton.fused_moe # noqa
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
fused_experts,
fused_topk,
get_config_file_name,
grouped_topk,
)
from sglang.srt.layers.fused_moe_triton.layer import (
from sglang.srt.layers.moe.fused_moe_triton.layer import (
FusedMoE,
FusedMoEMethodBase,
FusedMoeWeightScaleSupported,
@@ -37,8 +35,6 @@ __all__ = [
"override_config",
"get_config",
"fused_moe",
"fused_topk",
"fused_experts",
"get_config_file_name",
"grouped_topk",
]

View File

@@ -13,6 +13,7 @@ import triton
import triton.language as tl
from vllm import _custom_ops as ops
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.utils import direct_register_custom_op, get_device_name
logger = logging.getLogger(__name__)
@@ -415,7 +416,7 @@ def try_get_optimal_moe_config(
M: int,
is_marlin: bool = False,
):
from sglang.srt.layers.fused_moe_triton import get_config
from sglang.srt.layers.moe.fused_moe_triton import get_config
override_config = get_config()
if override_config:
@@ -435,74 +436,6 @@ def try_get_optimal_moe_config(
return config
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
M, _ = hidden_states.shape
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
token_expert_indicies = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(), # TODO(woosuk): Optimize this.
)
del token_expert_indicies # Not used. Will be used in the future.
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
# This is used by the Deepseek-V2 model
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def get_config_dtype_str(
dtype: torch.dtype,
use_int8_w8a16: Optional[bool] = False,
@@ -869,24 +802,16 @@ def fused_moe(
# Check constraints.
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
if use_grouped_topk:
assert num_expert_group is not None and topk_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states,
gating_output,
topk,
renormalize,
num_expert_group,
topk_group,
)
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(
hidden_states, gating_output, topk, renormalize
)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states, gating_output, topk, renormalize
)
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=gating_output,
use_grouped_topk=use_grouped_topk,
top_k=topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
)
return fused_experts(
hidden_states,

View File

@@ -13,6 +13,7 @@ from vllm.distributed import (
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.custom_op_util import register_custom_op
from sglang.srt.layers.moe.topk import select_experts
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
@@ -20,7 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.utils import set_weight_attrs
if torch.cuda.is_available():
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
else:
fused_experts = None # type: ignore
@@ -106,6 +107,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return self.forward(
x=x,
@@ -117,6 +119,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
)
def forward_cuda(
@@ -130,8 +133,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
@@ -140,6 +144,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
)
return fused_experts(
@@ -197,6 +202,7 @@ class FusedMoE(torch.nn.Module):
tp_size: Optional[int] = None,
prefix: str = "",
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
):
super().__init__()
@@ -217,6 +223,7 @@ class FusedMoE(torch.nn.Module):
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.custom_routing_function = custom_routing_function
self.correction_bias = correction_bias
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = (
@@ -503,51 +510,6 @@ class FusedMoE(torch.nn.Module):
)
return
@staticmethod
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
):
from sglang.srt.layers.fused_moe_triton.fused_moe import (
fused_topk,
grouped_topk,
)
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
)
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
return topk_weights, topk_ids
def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
assert self.quant_method is not None
@@ -562,6 +524,7 @@ class FusedMoE(torch.nn.Module):
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
correction_bias=self.correction_bias,
)
if self.reduce_results and self.tp_size > 1:

View File

@@ -0,0 +1,191 @@
from typing import Callable, Optional
import torch
import torch.nn.functional as F
def fused_topk_native(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
M, _ = hidden_states.shape
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
topk_weights = F.softmax(gating_output.float(), dim=-1)
topk_weights, topk_ids = torch.topk(topk_weights, topk, dim=-1)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
def fused_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
):
from vllm import _custom_ops as ops
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
M, _ = hidden_states.shape
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
token_expert_indicies = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
ops.topk_softmax(
topk_weights,
topk_ids,
token_expert_indicies,
gating_output.float(),
)
del token_expert_indicies
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
# This is used by the Deepseek-V2 model
def grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
scores = torch.softmax(gating_output, dim=-1)
num_token = scores.shape[0]
group_scores = (
scores.view(num_token, num_expert_group, -1).max(dim=-1).values
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0) # [n, e]
topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def biased_grouped_topk(
hidden_states: torch.Tensor,
gating_output: torch.Tensor,
correction_bias: torch.Tensor,
topk: int,
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
):
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
scores = gating_output.sigmoid()
num_token = scores.shape[0]
scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
group_scores = (
scores_for_choice.view(num_token, num_expert_group, -1)
.topk(2, dim=-1)[0]
.sum(dim=-1)
) # [n, n_group]
group_idx = torch.topk(group_scores, k=topk_group, dim=-1, sorted=False)[
1
] # [n, top_k_group]
group_mask = torch.zeros_like(group_scores) # [n, n_group]
group_mask.scatter_(1, group_idx, 1) # [n, n_group]
score_mask = (
group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, scores.shape[-1] // num_expert_group)
.reshape(num_token, -1)
) # [n, e]
tmp_scores = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # [n, e]
_, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
topk_weights = scores.gather(1, topk_ids)
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
torch_native: bool = False,
):
# DeekSeekv2 uses grouped_top_k
if use_grouped_topk:
assert topk_group is not None
assert num_expert_group is not None
if correction_bias is None:
topk_weights, topk_ids = grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
)
else:
topk_weights, topk_ids = biased_grouped_topk(
hidden_states=hidden_states,
gating_output=router_logits,
correction_bias=correction_bias,
topk=top_k,
renormalize=renormalize,
num_expert_group=num_expert_group,
topk_group=topk_group,
)
elif torch_native:
topk_weights, topk_ids = fused_topk_native(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
elif custom_routing_function is None:
topk_weights, topk_ids = fused_topk(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
else:
topk_weights, topk_ids = custom_routing_function(
hidden_states=hidden_states,
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
)
return topk_weights, topk_ids

View File

@@ -60,8 +60,8 @@ def fp8_get_quant_method(self, layer, prefix):
is_layer_skipped,
)
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.linear import UnquantizedLinearMethod
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod, Fp8MoEMethod
if isinstance(layer, LinearBase):
@@ -80,7 +80,7 @@ def gptq_get_quant_method(self, layer, prefix):
GPTQMarlinMoEMethod,
)
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
if isinstance(layer, LinearBase):
return GPTQMarlinLinearMethod(self)
@@ -96,7 +96,7 @@ def awq_get_quant_method(self, layer, prefix):
AWQMoEMethod,
)
from sglang.srt.layers.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
if isinstance(layer, LinearBase):
return AWQMarlinLinearMethod(self)

View File

@@ -26,8 +26,8 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
)
from vllm.model_executor.parameter import ModelWeightParameter, PerTensorScaleParameter
from sglang.srt.layers.fused_moe_triton.fused_moe import padding_size
from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import padding_size
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
@@ -98,7 +98,7 @@ class Fp8Config(QuantizationConfig):
) -> Optional["QuantizeMethodBase"]:
from vllm.attention.layer import Attention # Avoid circular import
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
if is_layer_skipped(prefix, self.ignored_layers):
@@ -320,7 +320,7 @@ class Fp8MoEMethod:
"""
def __new__(cls, *args, **kwargs):
from sglang.srt.layers.fused_moe_triton import FusedMoEMethodBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
if not hasattr(cls, "_initialized"):
original_init = cls.__init__
@@ -349,7 +349,7 @@ class Fp8MoEMethod:
params_dtype: torch.dtype,
**extra_weight_attrs,
):
from sglang.srt.layers.fused_moe_triton import FusedMoeWeightScaleSupported
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.float8_e4m3fn
@@ -566,12 +566,14 @@ class Fp8MoEMethod:
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
correction_bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
from sglang.srt.layers.moe.topk import select_experts
# Expert selection
topk_weights, topk_ids = FusedMoE.select_experts(
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
@@ -580,6 +582,7 @@ class Fp8MoEMethod:
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
)
# Expert fusion with FP8 quantization

View File

@@ -25,12 +25,12 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.fused_moe_patch import fused_moe_forward_native
from sglang.srt.layers.logits_processor import (
LogitsMetadata,
LogitsProcessor,
LogitsProcessorOutput,
)
from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
from sglang.srt.utils import maybe_torch_compile, monkey_patch_vllm_all_gather

View File

@@ -27,13 +27,13 @@ from vllm.distributed import (
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.fused_moe_triton import fused_moe
from sglang.srt.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (

View File

@@ -29,7 +29,6 @@ from vllm.distributed import (
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton import fused_moe
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (

View File

@@ -19,6 +19,7 @@
from typing import Any, Dict, Iterable, Optional, Tuple
import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from vllm import _custom_ops as ops
@@ -31,8 +32,6 @@ from vllm.distributed import (
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.ep_moe.layer import EPMoE
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
ColumnParallelLinear,
@@ -41,6 +40,8 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (
@@ -90,6 +91,24 @@ class DeepseekV2MLP(nn.Module):
return x
class MoEGate(nn.Module):
def __init__(self, config):
super().__init__()
self.weight = nn.Parameter(
torch.empty((config.n_routed_experts, config.hidden_size))
)
if config.topk_method == "noaux_tc":
self.e_score_correction_bias = nn.Parameter(
torch.empty((config.n_routed_experts))
)
else:
self.e_score_correction_bias = None
def forward(self, hidden_states):
logits = F.linear(hidden_states, self.weight, None)
return logits
class DeepseekV2MoE(nn.Module):
def __init__(
@@ -114,6 +133,8 @@ class DeepseekV2MoE(nn.Module):
"Only silu is supported for now."
)
self.gate = MoEGate(config=config)
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
self.experts = MoEImpl(
num_experts=config.n_routed_experts,
@@ -125,11 +146,9 @@ class DeepseekV2MoE(nn.Module):
use_grouped_topk=True,
num_expert_group=config.n_group,
topk_group=config.topk_group,
correction_bias=self.gate.e_score_correction_bias,
)
self.gate = ReplicatedLinear(
config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
)
if config.n_shared_experts is not None:
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
self.shared_experts = DeepseekV2MLP(
@@ -146,7 +165,7 @@ class DeepseekV2MoE(nn.Module):
if self.n_shared_experts is not None:
shared_output = self.shared_experts(hidden_states)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
router_logits = self.gate(hidden_states)
final_hidden_states = (
self.experts(hidden_states=hidden_states, router_logits=router_logits)
* self.routed_scaling_factor
@@ -439,7 +458,10 @@ class DeepseekV2AttentionMLA(nn.Module):
quant_config=quant_config,
)
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
rope_scaling["rope_type"] = "deepseek_yarn"
if rope_scaling:
rope_scaling["rope_type"] = "deepseek_yarn"
self.rotary_emb = get_rope(
qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
@@ -454,6 +476,8 @@ class DeepseekV2AttentionMLA(nn.Module):
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
else:
self.rotary_emb.forward = self.rotary_emb.forward_native
self.attn_mqa = RadixAttention(
self.num_local_heads,

View File

@@ -26,7 +26,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (

View File

@@ -27,8 +27,6 @@ from vllm.distributed import (
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.ep_moe.layer import EPMoE
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
QKVParallelLinear,
@@ -36,6 +34,8 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (

View File

@@ -36,9 +36,9 @@ from vllm.model_executor.layers.linear import (
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (

View File

@@ -29,7 +29,6 @@ from vllm.distributed import (
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.fused_moe_triton import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (

View File

@@ -33,8 +33,8 @@ from vllm.model_executor.layers.linear import (
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from sglang.srt.layers.fused_moe_triton import fused_moe
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.vocab_parallel_embedding import (