# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project

from typing import Callable, Optional

import torch

from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod

from vllm_mlu import _mlu_ops as mlu_ops
from vllm_mlu.mlu_hijack_utils import MluHijackObject
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts

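
# The function name encodes the upstream attribute being replaced:
# vllm.model_executor.layers.fused_moe.layer.UnquantizedFusedMoEMethod.forward_oot.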
def vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
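    """MLU implementation of ``UnquantizedFusedMoEMethod.forward_oot``.

    Dispatches to the fused MLU MoE kernel when plain top-k routing is
    requested, and otherwise falls back to explicit expert selection via
    ``FusedMoE.select_experts`` followed by ``fused_experts``.
    """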
    # TODO: support `routed_scaling_factor`
    assert routed_scaling_factor == 1.0, (
        f"routed_scaling_factor {routed_scaling_factor} is not supported for MLU."
    )
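
    # The fused MLU kernel performs its own plain top-k routing, so it is
    # only usable when grouped top-k routing is not requested.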
    use_fused_kernel = topk_group is None
    if use_fused_kernel:
        assert not enable_eplb, "MLU does not support EPLB in the fused_moe kernel."
        assert use_grouped_topk is False and num_expert_group is None and topk_group is None, \
            "use_grouped_topk, num_expert_group and topk_group are not supported yet."
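        # Fast path: a single MLU kernel fuses top-k routing and the gated
        # expert MLPs. Bias, residual, smoothing and scale arguments are None
        # because this method handles unquantized weights.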
        return mlu_ops.fused_moe(
            x,
            router_logits,
            layer.w13_weight, layer.w2_weight,
            None, None,  # bias1, bias2
            None,  # residual
            None,  # input_smooth
            None,  # act_smooth
            None, None,  # w1_scale, w2_scale
            top_k,
            renormalize,
            True,  # gated
            activation,
        )
    else:
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)
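
        # rocm_aiter_moe_enabled comes from the upstream method and is
        # presumably False on MLU; the branch is kept for parity with vLLM.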
        if self.rocm_aiter_moe_enabled:
            assert expert_map is None
            return self.rocm_aiter_fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input)
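        # Default MLU path: run the expert MLPs with the routing decisions
        # computed above; inplace=True reuses the hidden-states buffer.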
        else:
            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
            )

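
# Install the override: apply_hijack replaces
# UnquantizedFusedMoEMethod.forward_oot with the MLU implementation above.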
MluHijackObject.apply_hijack(
    UnquantizedFusedMoEMethod,
    UnquantizedFusedMoEMethod.forward_oot,
    vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot
)