################################################################################
# Copyright(c)2020-2025 Shanghai Biren Technology Co., Ltd. All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
################################################################################

from typing import Optional

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.distributed import (get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2MLP,
                                                    ParallelConfig)
from vllm_br import envs
from vllm_br.utils import get_grandparent_pid


class DeepseekV2MoE(nn.Module):
    """DeepSeek-V2 mixture-of-experts block for the vllm_br backend.

    Holds an unquantized router gate, the routed experts (``FusedMoE``) and,
    when ``config.n_shared_experts`` is set, a shared-expert MLP. In
    ``forward`` the gate and shared-expert weights are handed to the fused
    MoE call as a tuple instead of precomputed router logits.
    """

    def __init__(
        self,
        config: PretrainedConfig,
        parallel_config: ParallelConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the router gate, routed experts and shared experts.

        Args:
            config: HF model config. Reads hidden_size, n_routed_experts,
                num_experts_per_tok, moe_intermediate_size, norm_topk_prob,
                n_group, topk_group, scoring_func, topk_method, hidden_act,
                n_shared_experts and routed_scaling_factor.
            parallel_config: only ``use_sequence_parallel_moe`` is read.
            quant_config: quantization config forwarded to the routed and
                shared experts. The gate is always built with
                ``quant_config=None``.
            prefix: parameter-name prefix for the submodules.

        Raises:
            ValueError: if ``config.hidden_act`` is not ``"silu"``.
        """
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        # NOTE(review): routed_scaling_factor is stored but neither applied
        # in forward() nor passed to FusedMoE in this class; presumably the
        # vllm_br fused kernel or weight loading applies it -- confirm.
        self.routed_scaling_factor = config.routed_scaling_factor
        self.n_shared_experts = config.n_shared_experts
        # NOTE(review): not read anywhere in this class; presumably consumed
        # by other vllm_br code -- confirm before removing.
        self.static_moe_decoder_max_len = 512
        self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe
        if config.hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {config.hidden_act}. "
                             "Only silu is supported for now.")

        self.gate = ReplicatedLinear(config.hidden_size,
                                     config.n_routed_experts,
                                     bias=False,
                                     quant_config=None,
                                     prefix=f"{prefix}.gate")
        if config.topk_method == "noaux_tc":
            # Uninitialized CPU tensor; presumably filled by weight loading.
            # The same Parameter is shared with FusedMoE below.
            self.gate.e_score_correction_bias = nn.Parameter(
                torch.empty(config.n_routed_experts, device="cpu"))
        else:
            self.gate.e_score_correction_bias = None

        # reduce_results=False: the cross-TP reduction is handled in
        # forward() (or inside fused_moe, signalled by `all_reduced`).
        self.experts = FusedMoE(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            reduce_results=False,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            use_grouped_topk=True,
            num_expert_group=config.n_group,
            topk_group=config.topk_group,
            prefix=f"{prefix}.experts",
            scoring_func=config.scoring_func,
            e_score_correction_bias=self.gate.e_score_correction_bias)

        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = DeepseekV2MLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
                prefix=f"{prefix}.shared_experts",
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Apply the MoE block to ``hidden_states``.

        The input is flattened to ``(-1, hidden_states.shape[-1])`` for the
        fused call, and the result is reshaped back to the input's shape.

        Args:
            hidden_states: activations whose last dim is the hidden size.

        Returns:
            Tensor with the same shape as ``hidden_states``, reduced across
            the tensor-parallel group.

        Raises:
            ValueError: if the block was built without shared experts, which
                this fused path requires.
        """
        # Cache the grandparent PID once, only when CPU all-reduce is on;
        # presumably used by the CPU all-reduce path -- confirm.
        if envs.VLLM_BR_USE_CPU_ALL_REDUCE != 0 and not hasattr(
                self, "grandparent_pid"):
            self.grandparent_pid = get_grandparent_pid()
        orig_shape = hidden_states.shape
        # Explicit check instead of `assert` so it survives `python -O`.
        # __init__ only creates self.shared_experts when n_shared_experts
        # is set, and the fused call below dereferences it.
        if self.n_shared_experts is None:
            raise ValueError('n_shared_experts must be set')

        # NOTE: gate has been fused with shared_experts, no more single gate
        # call; router weights, shared-expert gate_up weights and shared-expert
        # down weights are packed in a tuple and passed as `router_logits`.
        tuple_router_shared_expert_weight = (
            self.gate.weight, self.shared_experts.gate_up_proj.weight,
            self.shared_experts.down_proj.weight)
        hidden_states = hidden_states.view(-1, orig_shape[-1])
        final_hidden_states = self.experts(
            hidden_states=hidden_states,
            router_logits=tuple_router_shared_expert_weight)
        if hasattr(final_hidden_states, 'all_reduced'):
            # NOTE: this flag indicates that final_hidden_states has already
            # been reduced in fused_moe; drop it so it does not leak further.
            delattr(final_hidden_states, 'all_reduced')
        elif self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states.view(orig_shape)