Files
2026-04-24 09:58:03 +08:00

106 lines
3.9 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
from typing import Optional, Callable
import torch
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
def vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Run the unquantized MoE forward pass on MLU.

    Two execution paths exist, selected solely by whether ``topk_group`` is
    ``None``:

    * ``topk_group is None``: routing and expert computation are delegated
      to ``mlu_ops.fused_moe``. Only ``x``, ``router_logits``, the layer's
      ``w13_weight``/``w2_weight``, ``top_k``, ``renormalize`` and
      ``activation`` are forwarded to that call.
    * otherwise: experts are chosen with ``FusedMoE.select_experts`` and
      then evaluated by ``self.rocm_aiter_fused_experts`` (when
      ``self.rocm_aiter_moe_enabled`` is truthy) or by ``fused_experts``.

    NOTE(review): on the ``mlu_ops.fused_moe`` path the arguments
    ``custom_routing_function``, ``scoring_func``,
    ``e_score_correction_bias``, ``expert_map``, ``global_num_experts`` and
    ``apply_router_weight_on_input`` are not forwarded -- confirm that
    ignoring non-default values there is intended.

    NOTE(review): ``enable_eplb`` is only rejected on the
    ``mlu_ops.fused_moe`` path, and ``expert_load_view``,
    ``logical_to_physical_map`` and ``logical_replica_count`` are never
    read in this function -- confirm EPLB handling for the
    ``select_experts`` path.

    Args:
        layer: Module providing ``w13_weight`` and ``w2_weight``.
        x: Hidden states handed to the MoE computation.
        router_logits: Router output used for expert selection.
        top_k: Number of experts selected per token.
        renormalize: Forwarded to the routing step of either path.
        use_grouped_topk: Must be ``False`` when ``topk_group`` is ``None``.
        topk_group: ``None`` selects the ``mlu_ops.fused_moe`` path.
        num_expert_group: Must be ``None`` when ``topk_group`` is ``None``.
        global_num_experts: Forwarded to ``fused_experts`` only.
        expert_map: Forwarded to ``fused_experts``; must be ``None`` when
            ``self.rocm_aiter_moe_enabled`` is truthy.
        custom_routing_function: Forwarded to ``select_experts`` only.
        scoring_func: Forwarded to ``select_experts`` only.
        routed_scaling_factor: Must equal ``1.0``.
        e_score_correction_bias: Forwarded to ``select_experts`` only.
        apply_router_weight_on_input: Forwarded on the ``select_experts``
            path only.
        activation: Activation name forwarded to whichever kernel runs.
        enable_eplb: Must be falsy when ``topk_group`` is ``None``.
        expert_load_view: Unused here.
        logical_to_physical_map: Unused here.
        logical_replica_count: Unused here.

    Returns:
        Whatever the selected kernel call returns.

    Raises:
        AssertionError: If ``routed_scaling_factor != 1.0``; if
            ``topk_group`` is ``None`` while ``enable_eplb`` is truthy,
            ``use_grouped_topk`` is not ``False`` or ``num_expert_group``
            is not ``None``; or if ``expert_map`` is given while
            ``self.rocm_aiter_moe_enabled`` is truthy.
    """
    # TODO: support `routed_scaling_factor`
    assert routed_scaling_factor == 1.0, (
        f"routed_scaling_factor {routed_scaling_factor} is not supported for MLU."
    )

    # Absence of `topk_group` is the only switch between the two paths.
    if topk_group is None:
        assert not enable_eplb, f"MLU not support eplb in fused_moe kernel."
        grouped_args_untouched = (
            use_grouped_topk is False
            and num_expert_group is None
            and topk_group is None
        )
        assert grouped_args_untouched, \
            f"Following params: use_grouped_topk, num_expert_group, topk_group are not support yet."
        # Positional call: the order of the placeholder slots matters.
        return mlu_ops.fused_moe(
            x,
            router_logits,
            layer.w13_weight,
            layer.w2_weight,
            None,  # bias1
            None,  # bias2
            None,  # residual
            None,  # input_smooth
            None,  # act_smooth
            None,  # w1_scale
            None,  # w2_scale
            top_k,
            renormalize,
            True,  # gated
            activation,
        )

    # Generic path: pick experts first, then run the expert kernel.
    topk_weights, topk_ids = FusedMoE.select_experts(
        hidden_states=x,
        router_logits=router_logits,
        use_grouped_topk=use_grouped_topk,
        top_k=top_k,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        indices_type=self.topk_indices_dtype,
    )

    if self.rocm_aiter_moe_enabled:
        assert expert_map is None
        run_experts = self.rocm_aiter_fused_experts
        backend_only_kwargs = {}
    else:
        run_experts = fused_experts
        # Arguments that only the `fused_experts` implementation receives.
        backend_only_kwargs = {
            "inplace": True,
            "global_num_experts": global_num_experts,
            "expert_map": expert_map,
        }

    return run_experts(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        **backend_only_kwargs,
    )
# Module-level registration, executed when this module is imported: hand the
# target class, its current `forward_oot` attribute and the MLU function
# defined in this file to the project's hijack helper.
# NOTE(review): presumably `apply_hijack` swaps the stock method for the MLU
# implementation -- confirm against `MluHijackObject`.
MluHijackObject.apply_hijack(
UnquantizedFusedMoEMethod,
UnquantizedFusedMoEMethod.forward_oot,
vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot
)