[Perf] Enable prefill flashcommon3 (#4065)

### What this PR does / why we need it?
Enable MoE multistream overlap for prefill (flashcommon3): the MoE gate (expert selection) and shared-expert computation run on a separate NPU stream so they overlap with the hidden-state quantization and all-gather, improving prefill performance.

### How was this patch tested?
Tested with the feature enabled via `--additional-config '{"multistream_overlap_gate": true}'`.
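As a usage sketch (not part of this PR), the same flag can also be passed programmatically through vLLM's `additional_config`; the model name and parallel size below are hypothetical placeholders, and this assumes an Ascend NPU environment with vllm-ascend installed:

```python
from vllm import LLM

# Hedged sketch: enable the gate-overlap feature via additional_config.
# Model name and tensor_parallel_size are illustrative placeholders.
llm = LLM(
    model="some-ascend-moe-model",
    tensor_parallel_size=4,
    additional_config={"multistream_overlap_gate": True},
)
print(llm.generate(["Hello"])[0].outputs[0].text)
```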

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: chenxiao <Jaychou1620@Gmail.com>
Co-authored-by: clrs97 <524936896@qq.com>
Co-authored-by: zzhx1 <zzh_201018@outlook.com>
Co-authored-by: chenxiao <Jaychou1620@Gmail.com>
Commit ba28d54f35 (parent 0686b32d82), authored by AlvisGong on 2025-12-14 09:34:13 +08:00, committed by GitHub.
8 changed files with 239 additions and 40 deletions


@@ -13,6 +13,10 @@ class TestPrepareAndFinalize(unittest.TestCase):
def setUp(self):
# Mock FusedMoEConfig
fake_stream = MagicMock()
patcher = patch("torch.npu.Stream", return_value=fake_stream)
patcher.start()
self.addCleanup(patcher.stop)
self.moe_config = MagicMock(spec=FusedMoEConfig)
self.moe_config.tp_group = MagicMock()
self.moe_config.tp_group.device_group = MagicMock()


@@ -106,6 +106,8 @@ class AscendConfig:
enable_shared_expert_dp=True)
self.multistream_overlap_shared_expert = additional_config.get(
"multistream_overlap_shared_expert", False)
self.multistream_overlap_gate = additional_config.get(
"multistream_overlap_gate", False)
self.recompute_scheduler_enable = additional_config.get(
"recompute_scheduler_enable", False)
self.enable_cpu_binding = additional_config.get(


@@ -20,9 +20,10 @@ _OTP: Optional[GroupCoordinator] = None
_LMTP: Optional[GroupCoordinator] = None
_EMBED_TP: Optional[GroupCoordinator] = None
# flashcomm2 specific groups
# flashcomm specific groups
_FLASHCOMM2_OTP: Optional[GroupCoordinator] = None
_FLASHCOMM2_ODP: Optional[GroupCoordinator] = None
_FC3_QUANT_X: Optional[GroupCoordinator] = None
# shared_weight across rank groups
_SHARED_WEIGHT: Optional[GroupCoordinator] = None
@@ -241,6 +242,15 @@ def init_ascend_model_parallel(parallel_config: ParallelConfig, ):
assert flashcomm2_otp_size == 1, "flashcomm2_o_shared is only supported when flashcomm2_otp_size is 1"
_SHARED_WEIGHT = _create_shared_weight_group("flashcomm2_o_shared")
if get_ascend_config().multistream_overlap_gate:
global _FC3_QUANT_X
group_ranks = all_ranks.unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_FC3_QUANT_X = init_model_parallel_group(group_ranks,
get_world_group().local_rank,
backend,
group_name="fc3_quant_x")
def model_parallel_initialized():
return (_MC2 is not None)
@@ -296,6 +306,11 @@ def get_p_tp_group() -> GroupCoordinator:
return _P_TP
def get_fc3_quant_x_group() -> GroupCoordinator:
assert _FC3_QUANT_X is not None, ("fc3 quant x group is not initialized")
return _FC3_QUANT_X
def destroy_ascend_model_parallel():
global _MC2
if _MC2:
@@ -343,3 +358,8 @@ def destroy_ascend_model_parallel():
if _SHARED_WEIGHT:
_SHARED_WEIGHT.destroy()
_SHARED_WEIGHT = None
global _FC3_QUANT_X
if _FC3_QUANT_X:
_FC3_QUANT_X.destroy()
_FC3_QUANT_X = None


@@ -2,8 +2,11 @@ import os
import torch
import torch.distributed as dist
from vllm.forward_context import get_forward_context
from vllm_ascend.distributed.parallel_state import get_p_tp_group
from vllm_ascend.distributed.parallel_state import (get_dp_group,
get_fc3_quant_x_group,
get_p_tp_group)
def kv_alltoall_and_rearrange(pd_tp_ratio: int, key: torch.Tensor,
@@ -59,3 +62,31 @@ def get_transfer_timeout_value():
'7')) # type: ignore
return int((4.096 * (2**hccl_rdma_timeout)) * hccl_rdma_retry_cnt // 1000 +
3000)
def fc3_all_gather_and_maybe_unpad_impl(x: torch.Tensor, ) -> torch.Tensor:
try:
forward_context = get_forward_context()
except AssertionError:
return x
x = get_fc3_quant_x_group().all_gather(x, 0)
dp_metadata = forward_context.dp_metadata
if dp_metadata is None:
pad_size = forward_context.pad_size
if pad_size > 0:
x = x[:-pad_size]
else:
# unpad
num_tokens_across_dp_cpu = dp_metadata.num_tokens_across_dp_cpu
result = torch.empty((num_tokens_across_dp_cpu.sum(), *x.shape[1:]),
device=x.device,
dtype=x.dtype)
dp_size = get_dp_group().world_size
x = x.view(dp_size, forward_context.padded_length, *x.shape[1:])
offset = 0
for idx in range(dp_size):
num_tokens_dp = num_tokens_across_dp_cpu[idx]
result[offset:offset + num_tokens_dp] = x[idx, :num_tokens_dp]
offset += num_tokens_dp
x = result
return x
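For reference, here is a CPU-only sketch of the unpadding performed above: after the all-gather the tensor holds `padded_length` rows per DP rank, and only the first `num_tokens_across_dp_cpu[idx]` rows of each rank are real tokens. The loop mirrors the per-rank copy in `fc3_all_gather_and_maybe_unpad_impl`; shapes are made up for illustration.

```python
import torch

dp_size, padded_length, hidden = 3, 4, 2
num_tokens_across_dp = torch.tensor([4, 2, 3])  # real tokens contributed by each DP rank

# After the all-gather: (dp_size * padded_length, hidden); padding rows stay zero here.
x = torch.zeros(dp_size * padded_length, hidden)
x_view = x.view(dp_size, padded_length, hidden)
for idx in range(dp_size):
    x_view[idx, :num_tokens_across_dp[idx]] = float(idx + 1)  # mark real rows of each rank

# Strip the per-rank padding, as the function above does.
result = torch.empty((int(num_tokens_across_dp.sum()), hidden))
offset = 0
for idx in range(dp_size):
    n = int(num_tokens_across_dp[idx])
    result[offset:offset + n] = x_view[idx, :n]
    offset += n

print(result.shape)  # torch.Size([9, 2]); rows ordered rank 0, rank 1, rank 2
```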


@@ -0,0 +1,42 @@
from dataclasses import dataclass
from typing import Optional
import torch
from vllm.model_executor.layers.linear import LinearBase
@dataclass
class FlashCommon3Context:
gate: Optional[LinearBase] = None
topk_weights: Optional[torch.Tensor] = None
topk_ids: Optional[torch.Tensor] = None
row_idx: Optional[torch.Tensor] = None
shared_experts: Optional[torch.nn.Module] = None
shared_out: Optional[torch.Tensor] = None
_flash_common3_context: Optional[FlashCommon3Context] = None
def get_flash_common3_context() -> Optional[FlashCommon3Context]:
return _flash_common3_context
def set_flash_common3_context(
topk_weights: Optional[torch.Tensor] = None,
topk_ids: Optional[torch.Tensor] = None,
shared_experts: Optional[torch.nn.Module] = None,
shared_out: Optional[torch.Tensor] = None,
):
global _flash_common3_context
if _flash_common3_context is None:
_flash_common3_context = FlashCommon3Context()
if topk_weights is not None:
_flash_common3_context.topk_weights = topk_weights
if topk_ids is not None:
_flash_common3_context.topk_ids = topk_ids
if shared_experts is not None:
_flash_common3_context.shared_experts = shared_experts
if shared_out is not None:
_flash_common3_context.shared_out = shared_out
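A minimal usage sketch of the new context helpers (assumes vllm-ascend with this change installed; shapes are illustrative): the gate stream stores routing results, and the quantized MoE method later reads them back instead of calling `select_experts` again.

```python
import torch

from vllm_ascend.flash_common3_context import (get_flash_common3_context,
                                               set_flash_common3_context)

# Producer side (the gate stream): stash routing results in the per-process context.
topk_weights = torch.rand(8, 2)              # illustrative shapes only
topk_ids = torch.randint(0, 64, (8, 2))
set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)

# Consumer side (e.g. the W8A8 dynamic MoE method): read them back.
ctx = get_flash_common3_context()
assert ctx is not None and ctx.topk_ids is not None
print(ctx.topk_ids.shape)
```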


@@ -37,9 +37,12 @@ from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import determine_default_log2phy_map
from vllm_ascend.eplb.utils import moe_load_async_stream
from vllm_ascend.flash_common3_context import (get_flash_common3_context,
set_flash_common3_context)
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.ops.fused_moe.moe_comm_method import setup_moe_comm_method
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
setup_moe_comm_method)
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.quantization.w4a8_dynamic import \
AscendW4A8DynamicFusedMoEMethod
@@ -139,6 +142,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
class AscendFusedMoE(FusedMoE):
moe_counter = -1
gate_stream: Optional[torch.npu.Stream] = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -170,6 +174,10 @@ class AscendFusedMoE(FusedMoE):
self.expert_map_path = ascend_config.expert_map_path
self.global_redundant_expert_num = ascend_config.init_redundancy_expert
self.global_num_experts = num_experts + self.global_redundant_expert_num
# flashcommon3 gate stream
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
if self.multistream_overlap_gate and AscendFusedMoE.gate_stream is None:
AscendFusedMoE.gate_stream = torch.npu.Stream()
if self.custom_routing_function is None and self.e_score_correction_bias is not None:
vllm_config = get_current_vllm_config()
self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(
@@ -332,6 +340,47 @@ class AscendFusedMoE(FusedMoE):
enable_force_load_balance = forward_context.in_profile_run
forward_context = get_forward_context()
if self.multistream_overlap_gate:
assert AscendFusedMoE.gate_stream is not None
fc3_context = get_flash_common3_context()
assert fc3_context is not None
AscendFusedMoE.gate_stream.wait_stream(torch.npu.current_stream())
with npu_stream_switch(AscendFusedMoE.gate_stream,
enabled=self.multistream_overlap_gate):
# share_expert
assert fc3_context.shared_experts is not None
shared_out = fc3_context.shared_experts(hidden_states)
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
set_flash_common3_context(shared_out=shared_out)
topk_weights, topk_ids = select_experts(
hidden_states=hidden_states,
router_logits=router_logits,
top_k=self.top_k,
use_grouped_topk=self.use_grouped_topk,
renormalize=self.renormalize,
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
global_num_experts=self.global_num_experts)
if isinstance(forward_context.moe_comm_method,
AllGatherCommImpl):
topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
topk_weights, True, True)
topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
topk_ids, True, True)
set_flash_common3_context(topk_weights=topk_weights,
topk_ids=topk_ids)
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
@@ -339,6 +388,10 @@ class AscendFusedMoE(FusedMoE):
enable_shared_expert_dp=self.enable_shared_expert_dp,
quant_type=self.quant_type)
# Make sure the default stream waits for the gate stream to finish.
if self.multistream_overlap_gate:
torch.npu.current_stream().wait_stream(AscendFusedMoE.gate_stream)
if isinstance(hidden_states, tuple):
hidden_states, pertoken_scale = hidden_states
else:
@@ -407,6 +460,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
self.shared_expert_stream = None
ascend_config = get_ascend_config()
self.multistream_overlap_shared_expert = ascend_config.multistream_overlap_shared_expert
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
if enable_sp():
logger.info_once(
"Sequence parallelism is enabled, shared experts are replicated for best performance."
@@ -443,30 +497,42 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
def forward_impl(self, hidden_states: torch.Tensor,
router_logits: torch.Tensor):
# Make sure the shared experts stream begins after hidden_states are ready.
if self.multistream_overlap_shared_expert:
shared_experts_calculation_stream().wait_stream( # type: ignore
torch.npu.current_stream())
with npu_stream_switch(shared_experts_calculation_stream(),
enabled=self.multistream_overlap_shared_expert):
# Use a separate stream to run shared experts.
# Note that currently we only support calculations in separate streams with aclgraph.
# Communication operations in another stream might cause unknown errors.
shared_out = self._shared_experts(hidden_states)
shared_out = None
if not self.multistream_overlap_gate:
# Make sure the shared experts stream begins after hidden_states are ready.
if self.multistream_overlap_shared_expert:
shared_experts_calculation_stream(
).wait_stream( # type: ignore
torch.npu.current_stream())
with npu_stream_switch(
shared_experts_calculation_stream(),
enabled=self.multistream_overlap_shared_expert):
# Use a separate stream to run shared experts.
shared_out = self._shared_experts(hidden_states)
else:
set_flash_common3_context(shared_experts=self._shared_experts)
fused_output = AscendFusedMoE.forward_impl(
self,
hidden_states=hidden_states,
router_logits=router_logits,
)
# Make sure the default stream waits for the shared experts stream to finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(
shared_experts_calculation_stream())
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_ALLTOALL} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
if not self.multistream_overlap_gate:
# Make sure the default stream waits for the shared experts stream to finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(
shared_experts_calculation_stream())
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2} \
and not shared_expert_dp_enabled():
shared_out = tensor_model_parallel_all_reduce(shared_out)
else:
fc3_context = get_flash_common3_context()
assert fc3_context is not None
shared_out = fc3_context.shared_out
return shared_out, fused_output
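The stream choreography above follows a standard produce-on-a-side-stream pattern: the side stream waits for its inputs, the default stream continues with dispatch work, and the default stream waits again before consuming the side stream's results. A condensed sketch follows, assuming `torch_npu` is available (so `torch.npu.Stream` and `torch.npu.stream` exist) and using placeholder `heavy_gate` / `heavy_dispatch` callables rather than the real MoE code:

```python
import torch
import torch_npu  # noqa: F401  -- registers the torch.npu device backend (assumption)

gate_stream = torch.npu.Stream()

def overlapped_forward(hidden_states, heavy_gate, heavy_dispatch):
    # The side stream must not start before hidden_states is ready on the default stream.
    gate_stream.wait_stream(torch.npu.current_stream())
    with torch.npu.stream(gate_stream):
        routing = heavy_gate(hidden_states)       # gate + shared experts, side stream
    dispatched = heavy_dispatch(hidden_states)    # quant / all-gather, default stream
    # The default stream must not consume `routing` before the side stream finishes.
    torch.npu.current_stream().wait_stream(gate_stream)
    return routing, dispatched
```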


@@ -29,7 +29,10 @@ from vllm.distributed.parallel_state import (
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.utils import enable_sp, prefill_context_parallel_enable
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
from vllm_ascend.utils import (enable_sp, npu_stream_switch,
prefill_context_parallel_enable)
class QuantType(Enum):
@@ -49,9 +52,14 @@ class PrepareAndFinalize(ABC):
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
sizes, ranks, and communication settings.
"""
quant_stream: Optional[torch.npu.Stream] = None
def __init__(self, moe_config: FusedMoEConfig):
self.moe_config = moe_config
ascend_config = get_ascend_config()
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
if self.multistream_overlap_gate and PrepareAndFinalize.quant_stream is None:
PrepareAndFinalize.quant_stream = torch.npu.Stream()
@abstractmethod
def prepare(
@@ -335,12 +343,28 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
if quant_type == QuantType.W8A8:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
if self.multistream_overlap_gate:
assert PrepareAndFinalize.quant_stream is not None
PrepareAndFinalize.quant_stream.wait_stream(
torch.npu.current_stream())
with npu_stream_switch(PrepareAndFinalize.quant_stream,
enabled=self.multistream_overlap_gate):
hidden_states = fc3_all_gather_and_maybe_unpad_impl(
hidden_states)
else:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states, True, True)
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
router_logits, True, True)
if pertoken_scale is not None:
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
pertoken_scale, True, True)
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states, True, True)
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
router_logits, True, True)
if self.multistream_overlap_gate:
torch.npu.current_stream().wait_stream(
PrepareAndFinalize.quant_stream)
if pertoken_scale is not None:
return (hidden_states, pertoken_scale), router_logits, None, None


@@ -26,6 +26,7 @@ from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.flash_common3_context import get_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import select_experts
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_enable_nz
@@ -114,6 +115,7 @@ class AscendW8A8DynamicFusedMoEMethod:
self.use_aclgraph = (vllm_config.compilation_config.mode
== CompilationMode.VLLM_COMPILE
and not vllm_config.model_config.enforce_eager)
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate
self.dynamic_eplb = ascend_config.dynamic_eplb or ascend_config.expert_map_record_path
self.in_dtype = vllm_config.model_config.dtype
@@ -198,18 +200,25 @@ class AscendW8A8DynamicFusedMoEMethod:
assert router_logits.shape[
1] == global_num_experts - global_redundant_expert_num, "Number of global experts mismatch (excluding redundancy)"
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
topk_weights, topk_ids = None, None
if self.multistream_overlap_gate:
fc3_context = get_flash_common3_context()
assert fc3_context is not None
topk_weights = fc3_context.topk_weights
topk_ids = fc3_context.topk_ids
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
# this is a naive implementation for experts load balance so as
# to avoid accumulating too much tokens on a single rank.
@@ -222,6 +231,7 @@ class AscendW8A8DynamicFusedMoEMethod:
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
assert topk_weights is not None
topk_weights = topk_weights.to(self.in_dtype)
moe_comm_method = get_forward_context().moe_comm_method