[main] [refactor] refactor fused_moe.py to enable token_dispatchers (#2570)
### What this PR does / why we need it?
Enable token_dispatcher to replace fused_experts_with_xxx in eager mode
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
e2e & ut
- vLLM version: v0.10.1.1
- vLLM main:
704432af3c
Signed-off-by: Pr0Wh1teGivee <calvin_zhu0210@outlook.com>
Co-authored-by: sherie <963372609@qq.com>
Co-authored-by: weijinqian0 <12153182+weijinqian0@users.noreply.github.com>
Co-authored-by: shiyuan680 <72335504+shiyuan680@users.noreply.github.com>
This commit is contained in:
@@ -26,9 +26,8 @@ from vllm.forward_context import get_forward_context
|
||||
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
|
||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||
from vllm_ascend.quantization.w8a8_dynamic import (fused_experts_with_all2all,
|
||||
fused_experts_with_mc2)
|
||||
|
||||
|
||||
class AscendW4A8DynamicLinearMethod:
|
||||
@@ -291,48 +290,25 @@ class AscendW4A8DynamicFusedMoEMethod:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
if fused_moe_state == FusedMoEState.MC2:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale_second,
|
||||
w2_scale=layer.w2_weight_scale_second,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
quantized_x_for_share=shared_gate_up,
|
||||
dynamic_scale_for_share=shared_dequant_scale,
|
||||
mc2_mask=kwargs.get("mc2_mask", None))
|
||||
else:
|
||||
# The current implementation of deepseek moe splits hidden_states
|
||||
# according to tp_size before they are feed into layers module.
|
||||
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
||||
# dispatch/combine tokens.
|
||||
return fused_experts_with_all2all(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale_second,
|
||||
w2_scale=layer.w2_weight_scale_second,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
ep_group=self.ep_group,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
)
|
||||
|
||||
return unified_fused_experts_eager(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
w1_scale=layer.w13_weight_scale_second,
|
||||
w2_scale=layer.w2_weight_scale_second,
|
||||
w1_scale_bias=layer.w13_scale_bias,
|
||||
w2_scale_bias=layer.w2_scale_bias,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
row_idx=row_idx,
|
||||
expert_map=expert_map,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts,
|
||||
shared_gate_up=shared_gate_up,
|
||||
shared_dequant_scale=shared_dequant_scale,
|
||||
mc2_mask=kwargs.get("mc2_mask", None))
|
||||
|
||||
def process_scale(self, weight: torch.Tensor, scale, per_group_scale):
|
||||
group_num, k, n = weight.shape
|
||||
|
||||
Reference in New Issue
Block a user