[perf] Support MoE multi-stream in DeepSeek (#947)
### What this PR does / why we need it?

Support multi-stream execution inside the MoE layer for DeepSeek models. This feature requires graph mode with MC2 enabled.

---------

Signed-off-by: David9857 <985700846@qq.com>
This commit is contained in:
@@ -114,5 +114,6 @@ def test_ascend_config_load_error():
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
with VllmRunner("facebook/opt-125m",
|
with VllmRunner("facebook/opt-125m",
|
||||||
|
enforce_eager=False,
|
||||||
additional_config=input_additional_config_fake_2):
|
additional_config=input_additional_config_fake_2):
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -53,6 +53,8 @@ class TorchairGraphConfig:
|
|||||||
"graph_batch_sizes", [])
|
"graph_batch_sizes", [])
|
||||||
self.graph_batch_sizes_init = torchair_graph_config.get(
|
self.graph_batch_sizes_init = torchair_graph_config.get(
|
||||||
"graph_batch_sizes_init", False)
|
"graph_batch_sizes_init", False)
|
||||||
|
self.enable_multistream_shared_expert = torchair_graph_config.get(
|
||||||
|
"enable_multistream_shared_expert", False)
|
||||||
|
|
||||||
if not isinstance(self.graph_batch_sizes, list):
|
if not isinstance(self.graph_batch_sizes, list):
|
||||||
raise TypeError("graph_batch_sizes must be list[int]")
|
raise TypeError("graph_batch_sizes must be list[int]")
|
||||||
@@ -105,7 +107,7 @@ def check_ascend_config(vllm_config, enforce_eager):
|
|||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
|
|
||||||
# Both for V0 and V1 Engine, torchair_graph cannot be enabled with eager mode.
|
# Both for V0 and V1 Engine, torchair_graph cannot be enabled with eager mode.
|
||||||
if ascend_config.torchair_graph_config.enabled and not enforce_eager:
|
if ascend_config.torchair_graph_config.enabled and enforce_eager:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
|
"Can't enable graph mode and eager mode at the same time. Please set `enforce_eager=False` if you attempt to enable NPU graph mode."
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -216,6 +216,8 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
|
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||||
|
self.enable_multistream_shared_expert = \
|
||||||
|
ascend_config.torchair_graph_config.enable_multistream_shared_expert
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
@@ -238,6 +240,8 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
|
|
||||||
num_tokens, hidden_size = hidden_states.shape
|
num_tokens, hidden_size = hidden_states.shape
|
||||||
|
|
||||||
|
multistream = self.enable_multistream_shared_expert and not is_prefill
|
||||||
|
|
||||||
old_hidden_states = hidden_states.clone()
|
old_hidden_states = hidden_states.clone()
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
@@ -259,13 +263,25 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
# router_logits: (num_tokens, n_experts)
|
# router_logits: (num_tokens, n_experts)
|
||||||
router_logits, _ = self.gate(hidden_states)
|
router_logits, _ = self.gate(hidden_states)
|
||||||
|
|
||||||
|
kwargs = {}
|
||||||
|
if multistream:
|
||||||
|
kwargs.update({
|
||||||
|
"shared_experts": self.shared_experts,
|
||||||
|
"shared_hidden_states": old_hidden_states
|
||||||
|
})
|
||||||
|
|
||||||
hidden_states = self.experts(
|
hidden_states = self.experts(
|
||||||
hidden_states=hidden_states,
|
hidden_states=hidden_states,
|
||||||
router_logits=router_logits,
|
router_logits=router_logits,
|
||||||
is_prefill=is_prefill,
|
is_prefill=is_prefill,
|
||||||
top_k=CustomDeepseekV2MoE.top_k,
|
top_k=CustomDeepseekV2MoE.top_k,
|
||||||
enable_force_load_balance=enable_force_load_balance,
|
enable_force_load_balance=enable_force_load_balance,
|
||||||
) * self.routed_scaling_factor
|
**kwargs)
|
||||||
|
|
||||||
|
if multistream:
|
||||||
|
hidden_states, shared_output = hidden_states
|
||||||
|
|
||||||
|
hidden_states = hidden_states * self.routed_scaling_factor
|
||||||
|
|
||||||
if self.tp_size > 1:
|
if self.tp_size > 1:
|
||||||
if self.torchair_graph_enabled:
|
if self.torchair_graph_enabled:
|
||||||
@@ -288,6 +304,7 @@ class CustomDeepseekV2MoE(nn.Module):
|
|||||||
hidden_states = hidden_states[:-num_padding_tokens]
|
hidden_states = hidden_states[:-num_padding_tokens]
|
||||||
|
|
||||||
if self.n_shared_experts is not None:
|
if self.n_shared_experts is not None:
|
||||||
|
if not multistream:
|
||||||
shared_output = self.shared_experts(old_hidden_states)
|
shared_output = self.shared_experts(old_hidden_states)
|
||||||
|
|
||||||
if shared_output is not None:
|
if shared_output is not None:
|
||||||
|
|||||||
@@ -39,8 +39,7 @@ VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
|||||||
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
|
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
|
||||||
|
|
||||||
|
|
||||||
def fused_experts_with_mc2(
|
def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
@@ -48,10 +47,10 @@ def fused_experts_with_mc2(
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
expert_map: torch.Tensor = None,
|
expert_map: torch.Tensor = None,
|
||||||
moe_all_to_all_group_name: Optional[str] = None,
|
moe_all_to_all_group_name: Optional[str] = None,
|
||||||
) -> torch.Tensor:
|
**kwargs) -> torch.Tensor:
|
||||||
global_bs = 0
|
global_bs = 0
|
||||||
moe_expert_num = len(expert_map)
|
moe_expert_num = len(expert_map)
|
||||||
kwargs = {
|
kwargs_mc2 = {
|
||||||
"x": hidden_states,
|
"x": hidden_states,
|
||||||
"expert_ids": topk_ids,
|
"expert_ids": topk_ids,
|
||||||
"expert_shard_type": 0,
|
"expert_shard_type": 0,
|
||||||
@@ -81,9 +80,9 @@ def fused_experts_with_mc2(
|
|||||||
"tp_world_size": tp_size,
|
"tp_world_size": tp_size,
|
||||||
"tp_rank_id": tp_rank,
|
"tp_rank_id": tp_rank,
|
||||||
}
|
}
|
||||||
kwargs.update(stage1_kwargs)
|
kwargs_mc2.update(stage1_kwargs)
|
||||||
|
|
||||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
|
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
||||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||||
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
||||||
0:5]
|
0:5]
|
||||||
@@ -119,7 +118,7 @@ def fused_experts_with_mc2(
|
|||||||
down_out_list = torch.cat(down_out_list, dim=0)
|
down_out_list = torch.cat(down_out_list, dim=0)
|
||||||
|
|
||||||
# moeCombine
|
# moeCombine
|
||||||
kwargs = {
|
kwargs_mc2 = {
|
||||||
"expand_x": down_out_list,
|
"expand_x": down_out_list,
|
||||||
"expert_ids": topk_ids,
|
"expert_ids": topk_ids,
|
||||||
"expand_idx": expand_idx,
|
"expand_idx": expand_idx,
|
||||||
@@ -141,9 +140,9 @@ def fused_experts_with_mc2(
|
|||||||
"tp_world_size": tp_size,
|
"tp_world_size": tp_size,
|
||||||
"tp_rank_id": tp_rank,
|
"tp_rank_id": tp_rank,
|
||||||
}
|
}
|
||||||
kwargs.update(stage3_kwargs)
|
kwargs_mc2.update(stage3_kwargs)
|
||||||
|
|
||||||
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
|
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
@@ -675,7 +674,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||||
|
**kwargs)
|
||||||
elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
|
elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
|
||||||
return fused_experts(hidden_states=x,
|
return fused_experts(hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
@@ -772,6 +772,8 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
|
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||||
|
self.enable_multistream_shared_expert = \
|
||||||
|
ascend_config.torchair_graph_config.enable_multistream_shared_expert
|
||||||
|
|
||||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||||
raise ValueError("Only softmax scoring function is supported for "
|
raise ValueError("Only softmax scoring function is supported for "
|
||||||
@@ -818,7 +820,8 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
router_logits: torch.Tensor,
|
router_logits: torch.Tensor,
|
||||||
is_prefill: bool,
|
is_prefill: bool,
|
||||||
enable_force_load_balance: bool = False,
|
enable_force_load_balance: bool = False,
|
||||||
top_k=None):
|
top_k=None,
|
||||||
|
**kwargs):
|
||||||
assert self.quant_method is not None
|
assert self.quant_method is not None
|
||||||
|
|
||||||
if top_k:
|
if top_k:
|
||||||
@@ -862,7 +865,11 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
scoring_func=self.scoring_func,
|
scoring_func=self.scoring_func,
|
||||||
e_score_correction_bias=self.e_score_correction_bias,
|
e_score_correction_bias=self.e_score_correction_bias,
|
||||||
is_prefill=is_prefill,
|
is_prefill=is_prefill,
|
||||||
enable_force_load_balance=enable_force_load_balance)
|
enable_force_load_balance=enable_force_load_balance,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
if self.enable_multistream_shared_expert and not is_prefill:
|
||||||
|
hidden_states, shared_output = hidden_states
|
||||||
|
|
||||||
if self.dp_size > 1:
|
if self.dp_size > 1:
|
||||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||||
@@ -886,4 +893,6 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||||
|
|
||||||
|
if self.enable_multistream_shared_expert and not is_prefill:
|
||||||
|
return hidden_states, shared_output
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|||||||
@@ -329,7 +329,7 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
|||||||
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
|
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
|
||||||
global_num_experts, expert_map, topk_group, num_expert_group,
|
global_num_experts, expert_map, topk_group, num_expert_group,
|
||||||
custom_routing_function, scoring_func, e_score_correction_bias,
|
custom_routing_function, scoring_func, e_score_correction_bias,
|
||||||
is_prefill, enable_force_load_balance)
|
is_prefill, enable_force_load_balance, **kwargs)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||||
|
|||||||
@@ -20,7 +20,8 @@ from typing import Any, Callable, Dict, Optional
|
|||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.distributed import GroupCoordinator
|
import torchair as tng # type: ignore
|
||||||
|
from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
@@ -38,7 +39,8 @@ def apply_mlp(hidden_states: torch.Tensor,
|
|||||||
w2_scale: torch.Tensor,
|
w2_scale: torch.Tensor,
|
||||||
group_list: torch.Tensor,
|
group_list: torch.Tensor,
|
||||||
dynamic_scale: torch.Tensor = None,
|
dynamic_scale: torch.Tensor = None,
|
||||||
group_list_type: int = 1) -> torch.Tensor:
|
group_list_type: int = 1,
|
||||||
|
**kwargs) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||||
|
|
||||||
@@ -72,6 +74,23 @@ def apply_mlp(hidden_states: torch.Tensor,
|
|||||||
else:
|
else:
|
||||||
pertoken_scale = dynamic_scale
|
pertoken_scale = dynamic_scale
|
||||||
|
|
||||||
|
shared_experts = kwargs.get('shared_experts', None)
|
||||||
|
if shared_experts:
|
||||||
|
shared_gate_up = kwargs.get('shared_gate_up', None)
|
||||||
|
shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
|
||||||
|
with tng.scope.npu_stream_switch('cv'):
|
||||||
|
tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
|
||||||
|
shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||||
|
x=shared_gate_up,
|
||||||
|
weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
|
||||||
|
activation_scale=shared_dynamic_scale,
|
||||||
|
bias=None,
|
||||||
|
quant_scale=None,
|
||||||
|
quant_offset=None,
|
||||||
|
group_index=None,
|
||||||
|
activate_left=True,
|
||||||
|
quant_mode=1)
|
||||||
|
|
||||||
# gmm1: gate_up_proj
|
# gmm1: gate_up_proj
|
||||||
hidden_states = torch_npu.npu_grouped_matmul(
|
hidden_states = torch_npu.npu_grouped_matmul(
|
||||||
x=[hidden_states],
|
x=[hidden_states],
|
||||||
@@ -100,11 +119,25 @@ def apply_mlp(hidden_states: torch.Tensor,
|
|||||||
group_type=0,
|
group_type=0,
|
||||||
group_list=group_list,
|
group_list=group_list,
|
||||||
output_dtype=w2_scale.dtype)[0]
|
output_dtype=w2_scale.dtype)[0]
|
||||||
|
|
||||||
|
if shared_experts:
|
||||||
|
with tng.scope.npu_stream_switch('cv'):
|
||||||
|
tng.scope.npu_wait_tensor(shared_x, hidden_states)
|
||||||
|
shared_output = torch_npu.npu_quant_matmul(
|
||||||
|
shared_x,
|
||||||
|
shared_experts.down_proj.weight,
|
||||||
|
shared_experts.down_proj.weight_scale,
|
||||||
|
pertoken_scale=shared_dynamic_scale,
|
||||||
|
output_dtype=torch.bfloat16,
|
||||||
|
)
|
||||||
|
if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
|
||||||
|
shared_output = tensor_model_parallel_all_reduce(shared_output)
|
||||||
|
if shared_experts:
|
||||||
|
return hidden_states, shared_output
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def fused_experts_with_mc2(
|
def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
||||||
hidden_states: torch.Tensor,
|
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
w2: torch.Tensor,
|
w2: torch.Tensor,
|
||||||
w1_scale: torch.Tensor,
|
w1_scale: torch.Tensor,
|
||||||
@@ -114,11 +147,11 @@ def fused_experts_with_mc2(
|
|||||||
top_k: int,
|
top_k: int,
|
||||||
expert_map: torch.Tensor = None,
|
expert_map: torch.Tensor = None,
|
||||||
moe_all_to_all_group_name: str = "",
|
moe_all_to_all_group_name: str = "",
|
||||||
) -> torch.Tensor:
|
**kwargs) -> torch.Tensor:
|
||||||
global_bs = 0
|
global_bs = 0
|
||||||
moe_expert_num = len(expert_map)
|
moe_expert_num = len(expert_map)
|
||||||
# hidden_states = hidden_states.bfloat16()
|
# hidden_states = hidden_states.bfloat16()
|
||||||
kwargs = {
|
kwargs_mc2 = {
|
||||||
"x": hidden_states,
|
"x": hidden_states,
|
||||||
"expert_ids": topk_ids,
|
"expert_ids": topk_ids,
|
||||||
"expert_shard_type": 0,
|
"expert_shard_type": 0,
|
||||||
@@ -149,9 +182,27 @@ def fused_experts_with_mc2(
|
|||||||
"tp_world_size": tp_size,
|
"tp_world_size": tp_size,
|
||||||
"tp_rank_id": tp_rank,
|
"tp_rank_id": tp_rank,
|
||||||
}
|
}
|
||||||
kwargs.update(stage1_kwargs)
|
kwargs_mc2.update(stage1_kwargs)
|
||||||
|
|
||||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs)
|
shared_experts = kwargs.get('shared_experts', None)
|
||||||
|
if shared_experts:
|
||||||
|
shared_hidden_states = kwargs.get('shared_hidden_states', None)
|
||||||
|
with tng.scope.npu_stream_switch('cv'):
|
||||||
|
tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
|
||||||
|
shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
|
||||||
|
shared_hidden_states)
|
||||||
|
shared_gate_up = torch_npu.npu_quant_matmul(
|
||||||
|
shared_x,
|
||||||
|
shared_experts.gate_up_proj.weight,
|
||||||
|
shared_experts.gate_up_proj.weight_scale,
|
||||||
|
output_dtype=torch.int32,
|
||||||
|
)
|
||||||
|
kwargs.update({
|
||||||
|
"shared_gate_up": shared_gate_up,
|
||||||
|
"shared_dynamic_scale": shared_dynamic_scale,
|
||||||
|
})
|
||||||
|
|
||||||
|
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
||||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||||
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
||||||
0:5]
|
0:5]
|
||||||
@@ -166,10 +217,15 @@ def fused_experts_with_mc2(
|
|||||||
w2,
|
w2,
|
||||||
w2_scale,
|
w2_scale,
|
||||||
expert_token_nums,
|
expert_token_nums,
|
||||||
dynamic_scale=dynamic_scale)
|
dynamic_scale=dynamic_scale,
|
||||||
|
**kwargs)
|
||||||
|
|
||||||
|
multi_stream = isinstance(down_out_list, tuple)
|
||||||
|
if multi_stream:
|
||||||
|
down_out_list, shared_output = down_out_list
|
||||||
|
|
||||||
# moeCombine
|
# moeCombine
|
||||||
kwargs = {
|
kwargs_mc2 = {
|
||||||
"expand_x": down_out_list,
|
"expand_x": down_out_list,
|
||||||
"expert_ids": topk_ids,
|
"expert_ids": topk_ids,
|
||||||
"expand_idx": expand_idx,
|
"expand_idx": expand_idx,
|
||||||
@@ -193,10 +249,12 @@ def fused_experts_with_mc2(
|
|||||||
"tp_world_size": tp_size,
|
"tp_world_size": tp_size,
|
||||||
"tp_rank_id": tp_rank,
|
"tp_rank_id": tp_rank,
|
||||||
}
|
}
|
||||||
kwargs.update(stage3_kwargs)
|
kwargs_mc2.update(stage3_kwargs)
|
||||||
|
|
||||||
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs)
|
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
||||||
|
|
||||||
|
if multi_stream:
|
||||||
|
return hidden_states, shared_output
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
@@ -634,7 +692,8 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||||
|
**kwargs)
|
||||||
elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
|
elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
|
||||||
return fused_experts(hidden_states=x,
|
return fused_experts(hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
|
|||||||
Reference in New Issue
Block a user