[MoE] [Refactor] Combine common_fused_moe and fused_moe (#3176)
### What this PR does / why we need it? 1. Move additional functionalities from fused_moe.py to common_fused_moe.py and remove fused_moe.py 2. Remove unnecessary custom classes from qwen3_moe.py, and it will be completely removed after we release vllm-ascend v0.11.0 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Qwen3-30B-A3B/Qwen3-30B-A3B-W8A8/DeepSeek-V3-W4A8-Pruing/deepseek-mtp/pangu-pro-moe-pruing: 1. Enable/Disable EP 3. Aclgraph & eager 4. SP - vLLM version: v0.11.0 --------- Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com> Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
This commit is contained in:
@@ -15,7 +15,7 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import os.path
|
||||
from typing import Callable, Optional
|
||||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
@@ -23,6 +23,7 @@ from vllm.config import CompilationLevel, get_current_vllm_config
|
||||
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
|
||||
from vllm.model_executor.layers.shared_fused_moe import SharedFusedMoE
|
||||
@@ -37,99 +38,110 @@ from vllm_ascend.ops.moe.experts_selector import select_experts
|
||||
from vllm_ascend.ops.moe.moe_comm_method import setup_moe_comm_method
|
||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p, npu_stream_switch
|
||||
|
||||
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
||||
|
||||
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
def unquantized_fused_moe_init_func(self, *args, **kwargs):
|
||||
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
|
||||
def __init__(self, moe: FusedMoEConfig = None):
|
||||
|
||||
# NOTE: Currently, this self.use_aclgraph is only used in
|
||||
# UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
|
||||
# ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
|
||||
# Once torch.randint_like is supported or removed, this flag can be removed.
|
||||
vllm_config = get_current_vllm_config()
|
||||
ascend_config = get_ascend_config()
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
self.use_aclgraph = False
|
||||
else:
|
||||
self.use_aclgraph = (vllm_config.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE
|
||||
and not vllm_config.model_config.enforce_eager)
|
||||
self.transpose = True
|
||||
super().__init__(moe=moe)
|
||||
|
||||
# NOTE: Currently, this self.use_aclgraph is only used in
|
||||
# UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
|
||||
# ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
|
||||
# Once torch.randint_like is supported or removed, this flag can be removed.
|
||||
vllm_config = get_current_vllm_config()
|
||||
ascend_config = get_ascend_config()
|
||||
self.dynamic_eplb = get_ascend_config().dynamic_eplb
|
||||
if ascend_config.torchair_graph_config.enabled:
|
||||
self.use_aclgraph = False
|
||||
else:
|
||||
self.use_aclgraph = (vllm_config.compilation_config.level
|
||||
== CompilationLevel.PIECEWISE and
|
||||
not vllm_config.model_config.enforce_eager)
|
||||
self.transpose = True
|
||||
|
||||
def forward_oot(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
activation: str = "silu",
|
||||
enable_eplb: bool = False,
|
||||
expert_load_view: Optional[torch.Tensor] = None,
|
||||
logical_to_physical_map: Optional[torch.Tensor] = None,
|
||||
logical_replica_count: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
def process_weights_after_loading(self, layer):
|
||||
super(UnquantizedFusedMoEMethod,
|
||||
self).process_weights_after_loading(layer)
|
||||
if self.transpose:
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data,
|
||||
requires_grad=False)
|
||||
|
||||
topk_weights, topk_ids, row_idx = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
return moe_comm_method.fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map)
|
||||
self.transpose = False
|
||||
else:
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data,
|
||||
requires_grad=False)
|
||||
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
|
||||
if self.transpose:
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
||||
if not is_310p():
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
enable_force_load_balance: bool = False,
|
||||
shared_experts: Optional[Any] = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
|
||||
self.transpose = False
|
||||
else:
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data)
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
||||
topk_weights, topk_ids, row_idx = select_experts(
|
||||
hidden_states=x,
|
||||
router_logits=router_logits,
|
||||
top_k=top_k,
|
||||
use_grouped_topk=use_grouped_topk,
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
custom_routing_function=custom_routing_function,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data)
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
# this is a naive implementation for experts load balance so as
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance and not self.use_aclgraph:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
if not is_310p():
|
||||
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
return moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
global_num_experts=global_num_experts,
|
||||
expert_map=expert_map,
|
||||
shared_experts=shared_experts,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
dynamic_eplb=self.dynamic_eplb)
|
||||
|
||||
|
||||
class AscendFusedMoE(FusedMoE):
|
||||
@@ -138,8 +150,26 @@ class AscendFusedMoE(FusedMoE):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
num_experts = kwargs["num_experts"]
|
||||
intermediate_size = kwargs["intermediate_size"]
|
||||
|
||||
AscendFusedMoE.moe_counter += 1
|
||||
self.moe_instance_id = AscendFusedMoE.moe_counter
|
||||
|
||||
self.global_num_experts = num_experts
|
||||
self.expert_map = None
|
||||
self.log2phy = None
|
||||
self.global_redundant_expert_num = 0
|
||||
|
||||
if self.quant_config is None:
|
||||
self.quant_method = AscendUnquantizedFusedMoEMethod(
|
||||
self.moe_config)
|
||||
else:
|
||||
self.quant_method = self.quant_config.get_quant_method(
|
||||
self, self.layer_name)
|
||||
|
||||
assert self.quant_method is not None
|
||||
|
||||
self.moe_config.tp_group = get_tp_group()
|
||||
self.moe_config.dp_group = get_dp_group()
|
||||
self.moe_config.ep_group = get_ep_group()
|
||||
@@ -148,6 +178,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.dynamic_eplb = ascend_config.dynamic_eplb
|
||||
self.expert_map_path = ascend_config.expert_map_path
|
||||
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
|
||||
self.global_num_experts = num_experts + self.global_redundant_expert_num
|
||||
# static eplb initializing with expert_map_path
|
||||
if self.expert_map_path and os.path.exists(
|
||||
self.expert_map_path) and os.access(self.expert_map_path,
|
||||
@@ -180,6 +211,25 @@ class AscendFusedMoE(FusedMoE):
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load = torch.zeros(local_num_experts, dtype=torch.int64)
|
||||
|
||||
self.moe_config.num_experts = self.global_num_experts
|
||||
self.moe_config.num_local_experts = self.local_num_experts
|
||||
self.moe_config.original_num_experts = num_experts
|
||||
|
||||
moe_quant_params = {
|
||||
"num_experts": local_num_experts,
|
||||
"hidden_size": self.hidden_size,
|
||||
"intermediate_size_per_partition":
|
||||
self.intermediate_size_per_partition,
|
||||
"params_dtype": self.params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
}
|
||||
# need full intermediate size pre-sharding for WNA16 act order
|
||||
if (self.quant_method.__class__.__name__
|
||||
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
|
||||
|
||||
setup_moe_comm_method(self.moe_config)
|
||||
|
||||
def update_expert_map(self, new_expert_map):
|
||||
@@ -210,11 +260,18 @@ class AscendFusedMoE(FusedMoE):
|
||||
router_logits: torch.Tensor):
|
||||
assert self.quant_method is not None
|
||||
|
||||
# For w8a8 dynamic we can do npu_dynamic_quant and gate in parallel.
|
||||
quantized_x_for_share, dynamic_scale_for_share = None, None
|
||||
|
||||
forward_context = get_forward_context()
|
||||
enable_force_load_balance = forward_context.in_profile_run
|
||||
|
||||
forward_context = get_forward_context()
|
||||
hidden_states, router_logits = forward_context.moe_comm_method.prepare(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
replace_allreduce=forward_context.sp_enabled)
|
||||
replace_allreduce=forward_context.sp_enabled,
|
||||
enable_shared_expert_dp=self.enable_shared_expert_dp)
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
@@ -233,11 +290,13 @@ class AscendFusedMoE(FusedMoE):
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
activation=self.activation,
|
||||
apply_router_weight_on_input=self.apply_router_weight_on_input,
|
||||
enable_eplb=self.enable_eplb,
|
||||
expert_load_view=self.expert_load_view,
|
||||
logical_to_physical_map=self.logical_to_physical_map,
|
||||
logical_replica_count=self.logical_replica_count,
|
||||
)
|
||||
quantized_x_for_share=quantized_x_for_share,
|
||||
dynamic_scale_for_share=dynamic_scale_for_share,
|
||||
shared_experts=None,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
log2phy=self.log2phy,
|
||||
global_redundant_expert_num=self.global_redundant_expert_num)
|
||||
|
||||
if isinstance(final_hidden_states, tuple):
|
||||
final_hidden_states, group_list_type, expert_tokens = final_hidden_states
|
||||
|
||||
@@ -361,8 +420,3 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
if self.multistream_overlap_shared_expert:
|
||||
torch.npu.current_stream().wait_stream(self.shared_expert_stream)
|
||||
return shared_out, fused_output
|
||||
|
||||
|
||||
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
|
||||
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
|
||||
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|
||||
|
||||
Reference in New Issue
Block a user