# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Callable, Optional

import torch

from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts


def vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # TODO: support `routed_scaling_factor`
    assert routed_scaling_factor == 1.0, (
        f"routed_scaling_factor {routed_scaling_factor} is not supported on MLU.")

    # The single fused MLU kernel covers the plain routing path; grouped top-k
    # routing falls back to select_experts() followed by fused_experts().
    use_fused_kernel = topk_group is None
    if use_fused_kernel:
        assert not enable_eplb, (
            "EPLB is not supported by the MLU fused_moe kernel.")
        assert (use_grouped_topk is False and num_expert_group is None
                and topk_group is None), (
            "use_grouped_topk, num_expert_group and topk_group are not "
            "supported yet.")
        return mlu_ops.fused_moe(
            x,
            router_logits,
            layer.w13_weight,
            layer.w2_weight,
            None, None,  # bias1, bias2
            None,        # residual
            None,        # input_smooth
            None,        # act_smooth
            None, None,  # w1_scale, w2_scale
            top_k,
            renormalize,
            True,        # gated activation
        )
    else:
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)

        if self.rocm_aiter_moe_enabled:
            assert expert_map is None
            return self.rocm_aiter_fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input)
        else:
            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
            )


# Replace vLLM's out-of-tree forward with the MLU implementation defined above.
MluHijackObject.apply_hijack(
    UnquantizedFusedMoEMethod,
    UnquantizedFusedMoEMethod.forward_oot,
    vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot,
)
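

# -----------------------------------------------------------------------------
# Hedged illustration, not part of the hijack itself: `MluHijackObject.apply_hijack`
# above is assumed to rebind `UnquantizedFusedMoEMethod.forward_oot` to the MLU
# override, i.e. a class-attribute monkey-patch. The toy class and names below are
# hypothetical and only sketch that pattern under that assumption.
if __name__ == "__main__":
    class _Demo:  # hypothetical stand-in for the patched vLLM class
        def forward_oot(self):
            return "original"

    def _patched_forward_oot(self):  # hypothetical stand-in for the MLU override
        return "patched"

    # Roughly equivalent in spirit to
    # apply_hijack(_Demo, _Demo.forward_oot, _patched_forward_oot).
    _Demo.forward_oot = _patched_forward_oot
    assert _Demo().forward_oot() == "patched"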