[Model] Support DeepSeek-V4
This commit is contained in:
106
vllm_mlu/model_executor/layers/fused_moe/layer.py
Normal file
106
vllm_mlu/model_executor/layers/fused_moe/layer.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM-MLU project
|
||||
|
||||
from typing import Optional, Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
|
||||
|
||||
from vllm_mlu import _mlu_ops as mlu_ops
|
||||
from vllm_mlu.mlu_hijack_utils import MluHijackObject
|
||||
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||
from vllm_mlu.model_executor.layers.fused_moe.fused_moe import fused_experts
|
||||
|
||||
|
||||
def vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    use_grouped_topk: bool = False,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    routed_scaling_factor: float = 1.0,
    e_score_correction_bias: Optional[torch.Tensor] = None,
    apply_router_weight_on_input: bool = False,
    activation: str = "silu",
    enable_eplb: bool = False,
    expert_load_view: Optional[torch.Tensor] = None,
    logical_to_physical_map: Optional[torch.Tensor] = None,
    logical_replica_count: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """MLU replacement for ``UnquantizedFusedMoEMethod.forward_oot``.

    Dispatches the MoE forward pass on Cambricon MLU hardware:

    * When no grouped top-k routing is requested (``topk_group is None``),
      routing + expert computation are done in a single fused MLU kernel
      (``mlu_ops.fused_moe``).
    * Otherwise, expert selection falls back to vLLM's
      ``FusedMoE.select_experts`` and the expert GEMMs run through either the
      ROCm AIter path (if enabled on ``self``) or the MLU ``fused_experts``
      kernel.

    The signature mirrors the upstream ``forward_oot`` so the hijack is a
    drop-in replacement.

    Args:
        layer: module holding the expert weights (``w13_weight``/``w2_weight``).
        x: input hidden states.
        router_logits: per-token routing logits.
        top_k: number of experts selected per token.
        renormalize: whether to renormalize the top-k routing weights.
        (remaining args follow the upstream vLLM ``forward_oot`` contract;
        ``routed_scaling_factor`` and EPLB-related args are not supported on
        MLU and are validated below.)

    Returns:
        The MoE output tensor, same shape as ``x``.
    """
    # TODO: support `routed_scaling_factor`
    assert routed_scaling_factor == 1.0, (
        f"routed_scaling_factor {routed_scaling_factor} is not supported for MLU."
    )
    # Grouped top-k routing cannot use the single fused kernel.
    use_fused_kernel = topk_group is None
    if use_fused_kernel:
        # NOTE: no placeholders in these messages, so plain strings (not
        # f-strings) are used.
        assert not enable_eplb, "MLU not support eplb in fused_moe kernel."
        assert use_grouped_topk is False and num_expert_group is None and topk_group is None, \
            "Following params: use_grouped_topk, num_expert_group, topk_group are not support yet."
        return mlu_ops.fused_moe(
            x,
            router_logits,
            layer.w13_weight, layer.w2_weight,
            None, None,  # bias1, bias2
            None,  # residual
            None,  # input_smooth
            None,  # act_smooth
            None, None,  # w1_scale, w2_scale
            top_k,
            renormalize,
            True,  # gated
            activation
        )
    else:
        # Unfused path: select experts with the stock vLLM routing, then run
        # the expert computation with an MLU (or ROCm AIter) kernel.
        topk_weights, topk_ids = FusedMoE.select_experts(
            hidden_states=x,
            router_logits=router_logits,
            use_grouped_topk=use_grouped_topk,
            top_k=top_k,
            renormalize=renormalize,
            topk_group=topk_group,
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            scoring_func=scoring_func,
            e_score_correction_bias=e_score_correction_bias,
            indices_type=self.topk_indices_dtype)

        if self.rocm_aiter_moe_enabled:
            # AIter path does not understand expert maps.
            assert expert_map is None
            return self.rocm_aiter_fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input)
        else:
            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                activation=activation,
                apply_router_weight_on_input=apply_router_weight_on_input,
                global_num_experts=global_num_experts,
                expert_map=expert_map,
            )
|
||||
|
||||
|
||||
# Monkey-patch the stock vLLM ``UnquantizedFusedMoEMethod.forward_oot`` with
# the MLU-aware implementation defined above. Registration goes through
# MluHijackObject so the hijack is tracked centrally (and can presumably be
# reverted by the same mechanism — see mlu_hijack_utils).
MluHijackObject.apply_hijack(
    UnquantizedFusedMoEMethod,
    UnquantizedFusedMoEMethod.forward_oot,
    vllm__model_executor__layers__fused_moe__layer__UnquantizedFusedMoEMethod__forward_oot
)
|
||||
Reference in New Issue
Block a user