[Perf]enable prefill flashcommon3 (#4065)

### What this PR does / why we need it?
moe multistream overlap to improve the performance.

### How was this patch tested?
--additional-config '{"multistream_overlap_gate": true}'

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
This commit is contained in:
AlvisGong
2025-12-14 09:34:13 +08:00
committed by GitHub
parent 0686b32d82
commit ba28d54f35
8 changed files with 239 additions and 40 deletions

View File

@@ -37,9 +37,12 @@ from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
from vllm_ascend.eplb.utils import moe_load_async_stream
from vllm_ascend.flash_common3_context import (get_flash_common3_context,
set_flash_common3_context)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
setup_moe_comm_method)
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.quantization.w4a8_dynamic import \
AscendW4A8DynamicFusedMoEMethod
@@ -139,6 +142,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
class AscendFusedMoE(FusedMoE):
moe_counter = -1
gate_stream: Optional[torch.npu.Stream] = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -170,6 +174,10 @@ class AscendFusedMoE(FusedMoE):
self.expert_map_path = ascend_config.expert_map_path
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.global_num_experts = num_experts + self.global_redundant_expert_num
# flashcommon3 gate stream
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
if self.multistream_overlap_gate and AscendFusedMoE.gate_stream is None:
AscendFusedMoE.gate_stream = torch.npu.Stream()
if self.custom_routing_function is None and self.e_score_correction_bias is not None:
vllm_config = get_current_vllm_config()
self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(
@@ -332,6 +340,47 @@ class AscendFusedMoE(FusedMoE):
enable_force_load_balance = forward_context.in_profile_run
forward_context = get_forward_context()
if self.multistream_overlap_gate:
assert AscendFusedMoE.gate_stream is not None
fc3_context = get_flash_common3_context()
assert fc3_context is not None
AscendFusedMoE.gate_stream.wait_stream(torch.npu.current_stream())
with npu_stream_switch(AscendFusedMoE.gate_stream,
enabled=self.multistream_overlap_gate):
# share_expert
assert fc3_context.shared_experts is not None
shared_out = fc3_context.shared_experts(hidden_states)
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
set_flash_common3_context(shared_out=shared_out)
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
global_num_experts=self.global_num_experts)
if isinstance(forward_context.moe_comm_method,
AllGatherCommImpl):
topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
topk_weights, True, True)
topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
topk_ids, True, True)
set_flash_common3_context(topk_weights=topk_weights,
topk_ids=topk_ids)
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
@@ -339,6 +388,10 @@ class AscendFusedMoE(FusedMoE):
enable_shared_expert_dp=self.enable_shared_expert_dp,
quant_type=self.quant_type)
# Make sure the default stream waits for the gate stream to finish.
if self.multistream_overlap_gate:
torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream)
if isinstance(hidden_states, tuple):
hidden_states, pertoken_scale = hidden_states
else:
@@ -407,6 +460,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
self.shared_expert_stream = None
ascend_config = get_ascend_config()
self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
if enable_sp():
logger.info_once(
"Sequence parallelism is enabled, shared experts are replicated for best performance."
@@ -443,30 +497,42 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
# Make sure the shared experts stream begins after hidden_states are ready.
if self.multistream_overlap_shared_expert:
shared_experts_calculation_stream().wait_stream( # type: ignore
torch.npu.current_stream())
with npu_stream_switch(shared_experts_calculation_stream(),
enabled=self.multistream_overlap_shared_expert):
# Use a separate stream to run shared experts.
# Note that currently we only support calculations in separate streams with aclgraph.
# Communication operations in another stream might cause unknown errors.
shared_out = self._shared_experts(hidden_states)
shared_out = None
if not self.multistream_overlap_gate:
# Make sure the shared experts stream begins after hidden_states are ready.
if self.multistream_overlap_shared_expert:
shared_experts_calculation_stream(
).wait_stream( # type: ignore
torch.npu.current_stream())
with npu_stream_switch(
shared_experts_calculation_stream(),
enabled=self.multistream_overlap_shared_expert):
# Use a separate stream to run shared experts.
shared_out = self._shared_experts(hidden_states)
else:
set_flash_common3_context(shared_experts=self._shared_experts)
fused_output = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
# Make sure the default stream waits for the shared experts stream to finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(
shared_experts_calculation_stream())
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
if not self.multistream_overlap_gate:
# Make sure the default stream waits for the shared experts stream to finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(
shared_experts_calculation_stream())
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
else:
fc3_context = get_flash_common3_context()
assert fc3_context is not None
shared_out = fc3_context.shared_out
return shared_out, fused_output