Reorg moe code (#2563)
This commit is contained in:
@@ -27,13 +27,13 @@ from vllm.distributed import (
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
from vllm.transformers_utils.configs.dbrx import DbrxConfig
|
||||
|
||||
from sglang.srt.layers.fused_moe_triton import fused_moe
|
||||
from sglang.srt.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
|
||||
@@ -29,7 +29,6 @@ from vllm.distributed import (
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.fused_moe_triton import fused_moe
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
|
||||
@@ -19,6 +19,7 @@
|
||||
from typing import Any, Dict, Iterable, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
from vllm import _custom_ops as ops
|
||||
@@ -31,8 +32,6 @@ from vllm.distributed import (
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
ColumnParallelLinear,
|
||||
@@ -41,6 +40,8 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
@@ -90,6 +91,24 @@ class DeepseekV2MLP(nn.Module):
|
||||
return x
|
||||
|
||||
|
||||
class MoEGate(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty((config.n_routed_experts, config.hidden_size))
|
||||
)
|
||||
if config.topk_method == "noaux_tc":
|
||||
self.e_score_correction_bias = nn.Parameter(
|
||||
torch.empty((config.n_routed_experts))
|
||||
)
|
||||
else:
|
||||
self.e_score_correction_bias = None
|
||||
|
||||
def forward(self, hidden_states):
|
||||
logits = F.linear(hidden_states, self.weight, None)
|
||||
return logits
|
||||
|
||||
|
||||
class DeepseekV2MoE(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -114,6 +133,8 @@ class DeepseekV2MoE(nn.Module):
|
||||
"Only silu is supported for now."
|
||||
)
|
||||
|
||||
self.gate = MoEGate(config=config)
|
||||
|
||||
MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
|
||||
self.experts = MoEImpl(
|
||||
num_experts=config.n_routed_experts,
|
||||
@@ -125,11 +146,9 @@ class DeepseekV2MoE(nn.Module):
|
||||
use_grouped_topk=True,
|
||||
num_expert_group=config.n_group,
|
||||
topk_group=config.topk_group,
|
||||
correction_bias=self.gate.e_score_correction_bias,
|
||||
)
|
||||
|
||||
self.gate = ReplicatedLinear(
|
||||
config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
|
||||
)
|
||||
if config.n_shared_experts is not None:
|
||||
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
||||
self.shared_experts = DeepseekV2MLP(
|
||||
@@ -146,7 +165,7 @@ class DeepseekV2MoE(nn.Module):
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
router_logits = self.gate(hidden_states)
|
||||
final_hidden_states = (
|
||||
self.experts(hidden_states=hidden_states, router_logits=router_logits)
|
||||
* self.routed_scaling_factor
|
||||
@@ -439,7 +458,10 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
quant_config=quant_config,
|
||||
)
|
||||
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
|
||||
rope_scaling["rope_type"] = "deepseek_yarn"
|
||||
|
||||
if rope_scaling:
|
||||
rope_scaling["rope_type"] = "deepseek_yarn"
|
||||
|
||||
self.rotary_emb = get_rope(
|
||||
qk_rope_head_dim,
|
||||
rotary_dim=qk_rope_head_dim,
|
||||
@@ -454,6 +476,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
scaling_factor = rope_scaling["factor"]
|
||||
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
|
||||
self.scaling = self.scaling * mscale * mscale
|
||||
else:
|
||||
self.rotary_emb.forward = self.rotary_emb.forward_native
|
||||
|
||||
self.attn_mqa = RadixAttention(
|
||||
self.num_local_heads,
|
||||
|
||||
@@ -26,7 +26,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
|
||||
from sglang.srt.layers.activation import GeluAndMul
|
||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
|
||||
@@ -27,8 +27,6 @@ from vllm.distributed import (
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
|
||||
from sglang.srt.layers.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
QKVParallelLinear,
|
||||
@@ -36,6 +34,8 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
|
||||
@@ -36,9 +36,9 @@ from vllm.model_executor.layers.linear import (
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
|
||||
@@ -29,7 +29,6 @@ from vllm.distributed import (
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
|
||||
from sglang.srt.layers.activation import SiluAndMul
|
||||
from sglang.srt.layers.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.layernorm import RMSNorm
|
||||
from sglang.srt.layers.linear import (
|
||||
MergedColumnParallelLinear,
|
||||
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
|
||||
RowParallelLinear,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
|
||||
@@ -33,8 +33,8 @@ from vllm.model_executor.layers.linear import (
|
||||
)
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
|
||||
from sglang.srt.layers.fused_moe_triton import fused_moe
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||
from sglang.srt.layers.moe.fused_moe_triton import fused_moe
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.layers.radix_attention import RadixAttention
|
||||
from sglang.srt.layers.vocab_parallel_embedding import (
|
||||
|
||||
Reference in New Issue
Block a user