[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)

### What this PR does / why we need it?

This PR refactors `forward_context` and `model_runner_v1`: it adds context that is needed during model inference to `forward_context`, and reworks the `dummy_run` logic to make it more reasonable.
Details of this PR:

- Add `ascend_forward_context` (see the sketch after this list);
- Update the mc2_v2 op and support the `active_mask` param;
- Update scripts in the examples dir;
- Refactor the `dummy_run` logic;
- Add `soc_version` for A2 and A3.
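
A rough sketch of how the new forward context is meant to be consumed on the MoE path (hedged: the helper name and return labels below are illustrative only; the diff itself only shows `get_forward_context().fused_moe_state` being read in `w8a8_dynamic`):

```python
from vllm.forward_context import get_forward_context
from vllm_ascend.ascend_forward_context import FusedMoEState


def select_moe_path() -> str:
    # The model runner is assumed to have populated fused_moe_state for this
    # step when it set up the ascend forward context.
    state = get_forward_context().fused_moe_state
    if state == FusedMoEState.AllGatherEP:
        return "allgather_ep"
    if state in (FusedMoEState.AllGather, FusedMoEState.NaiveMulticast):
        return "allgather"
    # Any other state falls through to the MC2 dispatch/combine path.
    return "mc2"
```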

### Does this PR introduce _any_ user-facing change?

No user-facing change.

### How was this patch tested?


- vLLM version: v0.10.0
- vLLM main: 57c22e57f9

Signed-off-by: zzzzwwjj <1183291235@qq.com>
Author: zzzzwwjj
Date: 2025-07-28 14:06:20 +08:00 (committed by GitHub)
Parent: e3a2443c3a
Commit: ba3dfbd59e
22 changed files with 629 additions and 347 deletions


@@ -22,8 +22,7 @@ from typing import Any, Dict, List, Optional
from vllm.logger import logger
from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot,
wrapper_rmsnorm_init)
from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
AscendW8A8LinearMethod)
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
@@ -81,9 +80,6 @@ class VLLMAscendQuantizer:
VLLMAscendQuantizer.apply_patch(
"vllm.model_executor.layers.layernorm.RMSNorm",
"forward_oot", [wrapper_rmsnorm_forward_oot])
VLLMAscendQuantizer.apply_patch(
"vllm_ascend.worker.model_runner.NPUModelRunnerBase",
"load_model", [wrapper_load_model])
break
VLLMAscendQuantizer.patched = True
logger.info("Using the vLLM Ascend Quantizer version now!")


@@ -20,15 +20,17 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
import torch.distributed as dist
import torch_npu
from vllm.distributed import GroupCoordinator
from vllm.distributed.parallel_state import get_ep_group
from vllm.distributed import GroupCoordinator, get_ep_group
from vllm.forward_context import get_forward_context
import vllm_ascend.envs as envs
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import FusedMoEState
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState,
dispose_tensor, get_fused_moe_state)
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
dispose_tensor, get_ascend_soc_version)
def apply_mlp(hidden_states: torch.Tensor,
@@ -118,10 +120,29 @@ def fused_experts_with_mc2(
log2phy: torch.Tensor = None,
global_redundant_expert_num: int = 0,
shared_experts: Optional[Any] = None,
is_torchair: bool = False,
quantized_x_for_share: Optional[Any] = None,
dynamic_scale_for_share: Optional[Any] = None,
mc2_mask: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
assert mc2_mask is not None
if log2phy is not None:
topk_ids = log2phy[topk_ids]
global_bs = 0
quant_mode = 2
ep_group = get_mc2_group()
ep_rank_id = ep_group.rank_in_group
ep_world_size = ep_group.world_size
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
or is_torchair)
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
if (expert_map is not None):
moe_expert_num = len(expert_map) + global_redundant_expert_num
else:
@@ -133,47 +154,43 @@ def fused_experts_with_mc2(
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": global_bs,
"expert_scales": topk_weights.to(torch.float32),
"global_bs": 0,
}
rank = torch.distributed.get_rank()
quant_mode = 2
ep_group = get_ep_group().device_group
local_rank = torch.distributed.get_rank(group=ep_group)
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
world_size = torch.distributed.get_world_size()
tp_size = world_size // all_to_all_group_size
tp_rank = rank % tp_size
stage1_kwargs = {
"scales": None,
"quant_mode": quant_mode,
"group_ep": moe_all_to_all_group_name,
"ep_world_size": all_to_all_group_size,
"ep_rank_id": local_rank,
# "group_tp": self.moe_rs_group_name,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
"ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id,
}
if need_extra_args:
stage1_kwargs.update({
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage1_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage1_kwargs)
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
output = torch_npu.npu_moe_distribute_dispatch_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[
0:7]
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
0:5]
if shared_experts is not None:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(hidden_states, topk_weights)
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
npu_wait_tensor(shared_gate_up[0], expand_x)
shared_act = shared_experts.act_fn(shared_gate_up)
npu_wait_tensor(quantized_x_for_share, expand_x)
shared_act_out = shared_experts.act_fn(
(quantized_x_for_share, dynamic_scale_for_share))
shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
# `expand_x` will be disposed in the `apply_mlp` function
down_out_list = apply_mlp(expand_x,
w1,
w1_scale,
@@ -186,13 +203,11 @@ def fused_experts_with_mc2(
kwargs_mc2 = {
"expand_x": down_out_list,
"expert_ids": topk_ids,
"expand_idx": expand_idx,
"expert_scales": topk_weights.to(torch.float32),
"expert_shard_type": 0,
"shared_expert_rank_num": 0,
"moe_expert_num": moe_expert_num,
"global_bs": 0,
"expand_scales": expand_scales,
}
tp_recv_counts = torch.empty(1,
dtype=torch.int32,
@@ -200,24 +215,43 @@ def fused_experts_with_mc2(
stage3_kwargs = {
"ep_send_counts": ep_recv_counts,
"group_ep": moe_all_to_all_group_name,
"ep_world_size": all_to_all_group_size,
"ep_rank_id": local_rank,
"tp_send_counts": tp_recv_counts,
# "group_tp": self.moe_rs_group_name,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": tp_size,
"tp_rank_id": tp_rank,
"ep_world_size": ep_world_size,
"ep_rank_id": ep_rank_id,
}
if enable_dispatch_v2:
stage3_kwargs.update({
"assist_info_for_combine":
assist_info_for_combine,
})
else:
stage3_kwargs.update({
"expand_idx": assist_info_for_combine,
})
if need_extra_args:
stage3_kwargs.update({
"tp_send_counts": tp_recv_counts,
"group_tp": moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
if a3_need_extra_args and enable_dispatch_v2:
stage3_kwargs.update({
"x_active_mask": mc2_mask,
})
kwargs_mc2.update(stage3_kwargs)
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
**kwargs_mc2
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
**kwargs_mc2)
if shared_experts is None:
return hidden_states
else:
with npu_stream_switch("moe_secondary", 0):
npu_wait_tensor(shared_act[0], down_out_list)
shared_output, _ = shared_experts.down_proj(shared_act)
npu_wait_tensor(shared_act, down_out_list)
shared_output, _ = shared_experts.down_proj(
(shared_act, swiglu_out_scale))
return hidden_states, shared_output
@@ -640,7 +674,7 @@ class AscendW8A8DynamicFusedMoEMethod:
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
try:
device_group = self.ep_group.device_group
device_group = get_mc2_group().device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
@@ -755,8 +789,7 @@ class AscendW8A8DynamicFusedMoEMethod:
topk_weights = topk_weights.to(x.dtype)
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
is_prefill, is_deepseek_v3_r1)
fused_moe_state = get_forward_context().fused_moe_state
if fused_moe_state == FusedMoEState.AllGatherEP:
return fused_experts_with_allgather(
hidden_states=x,
@@ -782,7 +815,9 @@ class AscendW8A8DynamicFusedMoEMethod:
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts)
shared_experts=shared_experts,
is_torchair=self.torchair_graph_enabled,
mc2_mask=kwargs.get("mc2_mask", None))
elif fused_moe_state in [
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
]:
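
To tie the MC2 changes above together, a minimal sketch of the capability-gated dispatch, assuming a `torch_npu` build that may or may not expose the v2 kernels (the standalone helper is illustrative only; the `group_ep`/`group_tp` plumbing is omitted):

```python
import torch_npu

from vllm_ascend.utils import AscendSocVersion, get_ascend_soc_version


def mc2_dispatch(kwargs_mc2: dict, mc2_mask, is_torchair: bool):
    # Prefer the v2 dispatch op when the installed torch_npu provides it.
    enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
    is_a3 = get_ascend_soc_version() == AscendSocVersion.A3
    # Extra TP-related params are only required on A3 or inside the torchair graph.
    if is_a3 or is_torchair:
        kwargs_mc2.update({"tp_world_size": 1, "tp_rank_id": 0})
    # The per-token active mask is only accepted by the v2 op, and only on A3.
    if is_a3 and enable_dispatch_v2:
        kwargs_mc2.update({"x_active_mask": mc2_mask})
    if enable_dispatch_v2:
        return torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2)
    return torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
```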