[Feature] support aclgraph for model runner v2 (#7110)
### What this PR does / why we need it?
This PR aims to support aclgraph for model runner v2, please see RFC
#5208. The PR contains these modifications:
- adapt to newest commit of vllm main branch.
- supply a unified interface of extra forward context for both model
runner v1 and model runner v2.
- implement graph mode for main model.
### Does this PR introduce _any_ user-facing change?
no
### How was this patch tested?
- vLLM version: v0.16.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: Ronald1995 <ronaldautomobile@163.com>
This commit is contained in:
@@ -37,7 +37,7 @@ if not vllm_version_is("0.16.0"):
|
||||
from vllm.model_executor.layers.fused_moe.runner.default_moe_runner import DefaultMoERunner # type: ignore
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ascend_forward_context import _EXTRA_CTX, MoECommType
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
|
||||
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
|
||||
@@ -148,7 +148,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device)
|
||||
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
moe_comm_method = _EXTRA_CTX.moe_comm_method
|
||||
final_hidden_states = moe_comm_method.fused_experts(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -401,12 +401,13 @@ class AscendFusedMoE(FusedMoE):
|
||||
# When static kernels are enabled, the forward pass runs twice (compilation + capture),
|
||||
# causing moe_layer_index to overflow. Wrap the index to prevent out-of-bounds errors.
|
||||
if self.enable_npugraph_ex_static_kernel:
|
||||
forward_context.moe_layer_index = forward_context.moe_layer_index % (len(forward_context.all_moe_layers))
|
||||
moe_layer_index = forward_context.moe_layer_index % (len(forward_context.all_moe_layers))
|
||||
forward_context.moe_layer_index = moe_layer_index
|
||||
|
||||
# Load balancing for token distribution among experts in dummy_run
|
||||
# TODO: The community only considers load balancing when DP > 1.
|
||||
# This approach may overlook some extreme scenarios.
|
||||
enable_force_load_balance = forward_context.in_profile_run
|
||||
enable_force_load_balance = _EXTRA_CTX.in_profile_run
|
||||
|
||||
forward_context = get_forward_context()
|
||||
if self.multistream_overlap_gate:
|
||||
@@ -419,7 +420,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
assert fc3_context.shared_experts is not None
|
||||
shared_out = fc3_context.shared_experts(hidden_states)
|
||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
moe_comm_type = _EXTRA_CTX.moe_comm_type
|
||||
if (
|
||||
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
|
||||
and not shared_expert_dp_enabled()
|
||||
@@ -442,16 +443,16 @@ class AscendFusedMoE(FusedMoE):
|
||||
global_num_experts=self.global_num_experts,
|
||||
)
|
||||
|
||||
if isinstance(forward_context.moe_comm_method, AllGatherCommImpl):
|
||||
if isinstance(_EXTRA_CTX.moe_comm_method, AllGatherCommImpl):
|
||||
topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_weights, True, True)
|
||||
topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_ids, True, True)
|
||||
|
||||
set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)
|
||||
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = _EXTRA_CTX.moe_comm_method.prepare(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
replace_allreduce=forward_context.flash_comm_v1_enabled,
|
||||
replace_allreduce=_EXTRA_CTX.flash_comm_v1_enabled,
|
||||
enable_shared_expert_dp=self.enable_shared_expert_dp,
|
||||
quant_type=self.quant_type,
|
||||
)
|
||||
@@ -509,7 +510,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.load_counter.add_(1)
|
||||
else:
|
||||
self.moe_load.add_(local_load)
|
||||
routed_out = forward_context.moe_comm_method.finalize(
|
||||
routed_out = _EXTRA_CTX.moe_comm_method.finalize(
|
||||
hidden_states=fused_experts_results.routed_out,
|
||||
reduce_results=self.reduce_results,
|
||||
context_metadata=context_metadata,
|
||||
@@ -670,8 +671,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
|
||||
# NOTE: This is exactly the opposite of
|
||||
# `maybe_all_reduce_tensor_model_parallel`
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
moe_comm_type = _EXTRA_CTX.moe_comm_type
|
||||
if (
|
||||
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
|
||||
and not shared_expert_dp_enabled()
|
||||
|
||||
Reference in New Issue
Block a user