[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)

### What this PR does / why we need it?

A refactoring of forward_context and model_runner_v1, add some context
which is necessary in model inference into forward_context, and refactor
dummy_run logic, make it more reasonable.
Some details for this PR:

Add `ascend_forward_context`;
Update mc2_v2 op, and support `active_mask` param;
Update scripts in examples dir;
refactor `dummy_run` logic;
Add soc_version for A2 and A3;

### Does this PR introduce _any_ user-facing change?

No change at user-facing.

### How was this patch tested?


- vLLM version: v0.10.0
- vLLM main:
57c22e57f9

Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
zzzzwwjj
2025-07-28 14:06:20 +08:00
committed by GitHub
parent e3a2443c3a
commit ba3dfbd59e
22 changed files with 629 additions and 347 deletions

View File

@@ -40,12 +40,15 @@ from vllm.model_executor.layers.quantization.base_config import \
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.communication_op import \
data_parallel_reduce_scatter
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
get_all_reduce_merge_state, get_fused_moe_state,
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state,
get_ascend_soc_version,
get_rm_router_logits_state, is_310p)
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -127,9 +130,23 @@ def fused_experts_with_mc2(
moe_parallel_config: FusedMoEParallelConfig,
expert_map: torch.Tensor = None,
moe_all_to_all_group_name: Optional[str] = None,
shared_experts: Optional[Any] = None
shared_experts: Optional[Any] = None,
is_torchair: bool = False,
mc2_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
global_bs = 0
quant_mode = 0
ep_rank_id = moe_parallel_config.ep_rank
ep_world_size = moe_parallel_config.ep_size
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
or is_torchair)
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
moe_expert_num = len(expert_map)
kwargs_mc2 = {
"x": hidden_states,
@@ -137,32 +154,35 @@ def fused_experts_with_mc2(
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": global_bs,
"global_bs": 0,
}
rank = torch.distributed.get_rank()
quant_mode = 0
ep_rank_id = moe_parallel_config.ep_rank
ep_world_size = moe_parallel_config.ep_size
tp_world_size = moe_parallel_config.tp_size
tp_rank = rank % tp_world_size
stage1_kwargs = {
"scales": None,
"quant_mode": quant_mode,
"group_ep": moe_all_to_all_group_name,
"ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": tp_world_size,
"tp_rank_id": tp_rank,
}
if need_extra_args:
stage1_kwargs.update({
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage1_kwargs)
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
output = torch_npu.npu_moe_distribute_dispatch_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
0:5]
if shared_experts is not None:
@@ -205,7 +225,6 @@ def fused_experts_with_mc2(
kwargs_mc2 = {
"expand_x": down_out_list,
"expert_ids": topk_ids,
"expand_idx": expand_idx,
"expert_scales": topk_weights.to(torch.float32),
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
@@ -218,15 +237,33 @@ def fused_experts_with_mc2(
"group_ep": moe_all_to_all_group_name,
"ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id,
"tp_send_counts": tp_recv_counts,
# "group_tp": self.moe_rs_group_name,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": tp_world_size,
"tp_rank_id": tp_rank,
}
if enable_dispatch_v2:
stage3_kwargs.update({
"assist_info_for_combine":
assist_info_for_combine,
})
else:
stage3_kwargs.update({
"expand_idx": assist_info_for_combine,
})
if need_extra_args:
stage3_kwargs.update({
"tp_send_counts": tp_recv_counts,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage3_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage3_kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
**kwargs_mc2)
if shared_experts is None:
return hidden_states
@@ -981,17 +1018,14 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
super().__init__(moe=moe)
vllm_config = get_current_vllm_config()
self.ep_group = get_ep_group()
self.ep_size = self.moe.moe_parallel_config.ep_size
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.local_batch_size = self.global_batch_size // self.ep_size
self.max_model_len = vllm_config.model_config.max_model_len
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
try:
device_group = self.ep_group.device_group
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
@@ -1074,8 +1108,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
fused_moe_state = get_fused_moe_state(self.ep_size, is_prefill,
is_deepseek_v3_r1)
fused_moe_state = get_forward_context().fused_moe_state
if fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2(
hidden_states=x,
@@ -1087,7 +1121,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
top_k=top_k,
expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
shared_experts=shared_experts)
shared_experts=shared_experts,
mc2_mask=kwargs.get("mc2_mask", None))
elif fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
]:
@@ -1295,52 +1330,56 @@ class AscendFusedMoE(FusedMoE):
real_top_k = self.top_k
num_tokens, hidden_size = hidden_states.shape
is_deepseek_v3_r1 = self.global_num_experts == 256
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
is_prefill, is_deepseek_v3_r1)
forward_context = get_forward_context()
fused_moe_state = forward_context.fused_moe_state
mc2_mask = forward_context.mc2_mask
if shared_experts:
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
shared_hidden_states = shared_experts(hidden_states)
tp_size = get_tensor_model_parallel_world_size()
if (tp_size > 1 and fused_moe_state not in [
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
] and not replace_allreduce):
if num_tokens < tp_size:
if fused_moe_state in {FusedMoEState.MC2}:
padding_size = forward_context.padded_num_tokens
else:
# TODO: Determine if we can remove the padding
padding_size = tp_size
if num_tokens < padding_size:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, tp_size - num_tokens))
hidden_states, (0, 0, 0, padding_size - num_tokens))
router_logits = nn.functional.pad(
router_logits, (0, 0, 0, tp_size - num_tokens))
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
tp_rank = get_tensor_model_parallel_rank()
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
router_logits, (0, 0, 0, padding_size - num_tokens))
if tp_size > 1:
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
tp_rank = get_tensor_model_parallel_rank()
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
mc2_mask = chunk_mc2_mask[tp_rank]
if self.dp_size > 1:
if fused_moe_state == FusedMoEState.AllGather:
# NOTE: When in torchair graph, it has been padded in model_runner_v1
if not self.torchair_graph_enabled:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None:
max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
if num_tokens < max_num_tokens_across_dp:
hidden_states = nn.functional.pad(
hidden_states,
(0, 0, 0,
max_num_tokens_across_dp - num_tokens))
if not self.rm_router_logits:
router_logits = nn.functional.pad(
router_logits,
(0, 0, 0,
max_num_tokens_across_dp - num_tokens))
max_tokens_across_dp = forward_context.max_tokens_across_dp
if num_tokens < max_tokens_across_dp:
hidden_states = nn.functional.pad(
hidden_states,
(0, 0, 0, max_tokens_across_dp - num_tokens))
if not self.rm_router_logits:
router_logits = nn.functional.pad(
router_logits,
(0, 0, 0, max_tokens_across_dp - num_tokens))
hidden_states = get_dp_group().all_gather(hidden_states, 0)
if self.rm_router_logits:
router_logits, _ = gate(hidden_states)
@@ -1379,20 +1418,24 @@ class AscendFusedMoE(FusedMoE):
global_redundant_expert_num=self.global_redundant_expert_num,
shared_experts=shared_experts if self.torchair_graph_enabled
and self.enable_multistream_moe and not is_prefill else None,
mc2_mask=mc2_mask,
)
if shared_experts:
if isinstance(e_hidden_states, tuple):
e_hidden_states, shared_hidden_states = e_hidden_states
if (tp_size > 1 and fused_moe_state not in [
if (fused_moe_state not in [
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
FusedMoEState.NaiveMulticast
] and not replace_allreduce):
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
if num_tokens < tp_size:
if tp_size > 1:
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
self.tp_group)
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
else:
final_hidden_states = e_hidden_states
if num_tokens < padding_size:
final_hidden_states = final_hidden_states[:num_tokens]
dispose_tensor(e_hidden_states)
elif self.dp_size > 1: