[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)
### What this PR does / why we need it?
Refactor forward_context and model_runner_v1: add the context that is
necessary for model inference to forward_context, and refactor the
dummy_run logic to make it more reasonable.
Some details for this PR:
Add `ascend_forward_context`;
Update mc2_v2 op, and support `active_mask` param;
Update scripts in examples dir;
Refactor `dummy_run` logic;
Add soc_version for A2 and A3;
### Does this PR introduce _any_ user-facing change?
No user-facing changes.
### How was this patch tested?
- vLLM version: v0.10.0
- vLLM main:
57c22e57f9
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -22,8 +22,7 @@ from typing import Any, Dict, List, Optional
|
||||
|
||||
from vllm.logger import logger
|
||||
|
||||
from .func_wrapper import (wrapper_load_model, wrapper_rmsnorm_forward_oot,
|
||||
wrapper_rmsnorm_init)
|
||||
from .func_wrapper import wrapper_rmsnorm_forward_oot, wrapper_rmsnorm_init
|
||||
from .w8a8 import (AscendC8KVCacheMethod, AscendW8A8FusedMoEMethod,
|
||||
AscendW8A8LinearMethod)
|
||||
from .w8a8_dynamic import (AscendW8A8DynamicFusedMoEMethod,
|
||||
@@ -81,9 +80,6 @@ class VLLMAscendQuantizer:
|
||||
VLLMAscendQuantizer.apply_patch(
|
||||
"vllm.model_executor.layers.layernorm.RMSNorm",
|
||||
"forward_oot", [wrapper_rmsnorm_forward_oot])
|
||||
VLLMAscendQuantizer.apply_patch(
|
||||
"vllm_ascend.worker.model_runner.NPUModelRunnerBase",
|
||||
"load_model", [wrapper_load_model])
|
||||
break
|
||||
VLLMAscendQuantizer.patched = True
|
||||
logger.info("Using the vLLM Ascend Quantizer version now!")
|
||||
|
||||
@@ -20,15 +20,17 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from vllm.distributed import GroupCoordinator
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
from vllm.distributed import GroupCoordinator, get_ep_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe import select_experts
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, FusedMoEState,
|
||||
dispose_tensor, get_fused_moe_state)
|
||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, AscendSocVersion,
|
||||
dispose_tensor, get_ascend_soc_version)
|
||||
|
||||
|
||||
def apply_mlp(hidden_states: torch.Tensor,
|
||||
@@ -118,10 +120,29 @@ def fused_experts_with_mc2(
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
is_torchair: bool = False,
|
||||
quantized_x_for_share: Optional[Any] = None,
|
||||
dynamic_scale_for_share: Optional[Any] = None,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
assert mc2_mask is not None
|
||||
if log2phy is not None:
|
||||
topk_ids = log2phy[topk_ids]
|
||||
global_bs = 0
|
||||
|
||||
quant_mode = 2
|
||||
ep_group = get_mc2_group()
|
||||
ep_rank_id = ep_group.rank_in_group
|
||||
ep_world_size = ep_group.world_size
|
||||
|
||||
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
|
||||
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
|
||||
or is_torchair)
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
|
||||
|
||||
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
|
||||
|
||||
if (expert_map is not None):
|
||||
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
||||
else:
|
||||
@@ -133,47 +154,43 @@ def fused_experts_with_mc2(
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": global_bs,
|
||||
"expert_scales": topk_weights.to(torch.float32),
|
||||
"global_bs": 0,
|
||||
}
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
quant_mode = 2
|
||||
ep_group = get_ep_group().device_group
|
||||
local_rank = torch.distributed.get_rank(group=ep_group)
|
||||
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
|
||||
|
||||
world_size = torch.distributed.get_world_size()
|
||||
tp_size = world_size // all_to_all_group_size
|
||||
tp_rank = rank % tp_size
|
||||
|
||||
stage1_kwargs = {
|
||||
"scales": None,
|
||||
"quant_mode": quant_mode,
|
||||
"group_ep": moe_all_to_all_group_name,
|
||||
"ep_world_size": all_to_all_group_size,
|
||||
"ep_rank_id": local_rank,
|
||||
# "group_tp": self.moe_rs_group_name,
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": tp_size,
|
||||
"tp_rank_id": tp_rank,
|
||||
"ep_world_size": ep_world_size,
|
||||
"ep_rank_id": ep_rank_id,
|
||||
}
|
||||
if need_extra_args:
|
||||
stage1_kwargs.update({
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if a3_need_extra_args and enable_dispatch_v2:
|
||||
stage1_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
})
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
|
||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
||||
output = torch_npu.npu_moe_distribute_dispatch_v2(
|
||||
**kwargs_mc2
|
||||
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||
**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts, _, expand_scales = output[
|
||||
0:7]
|
||||
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
|
||||
0:5]
|
||||
|
||||
if shared_experts is not None:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(hidden_states, topk_weights)
|
||||
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
||||
npu_wait_tensor(shared_gate_up[0], expand_x)
|
||||
shared_act = shared_experts.act_fn(shared_gate_up)
|
||||
npu_wait_tensor(quantized_x_for_share, expand_x)
|
||||
shared_act_out = shared_experts.act_fn(
|
||||
(quantized_x_for_share, dynamic_scale_for_share))
|
||||
shared_act, swiglu_out_scale = shared_act_out[0], shared_act_out[1]
|
||||
|
||||
# `expand_x` will be disposed in the `apply_mlp` function
|
||||
down_out_list = apply_mlp(expand_x,
|
||||
w1,
|
||||
w1_scale,
|
||||
@@ -186,13 +203,11 @@ def fused_experts_with_mc2(
|
||||
kwargs_mc2 = {
|
||||
"expand_x": down_out_list,
|
||||
"expert_ids": topk_ids,
|
||||
"expand_idx": expand_idx,
|
||||
"expert_scales": topk_weights.to(torch.float32),
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": 0,
|
||||
"expand_scales": expand_scales,
|
||||
}
|
||||
tp_recv_counts = torch.empty(1,
|
||||
dtype=torch.int32,
|
||||
@@ -200,24 +215,43 @@ def fused_experts_with_mc2(
|
||||
stage3_kwargs = {
|
||||
"ep_send_counts": ep_recv_counts,
|
||||
"group_ep": moe_all_to_all_group_name,
|
||||
"ep_world_size": all_to_all_group_size,
|
||||
"ep_rank_id": local_rank,
|
||||
"tp_send_counts": tp_recv_counts,
|
||||
# "group_tp": self.moe_rs_group_name,
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": tp_size,
|
||||
"tp_rank_id": tp_rank,
|
||||
"ep_world_size": ep_world_size,
|
||||
"ep_rank_id": ep_rank_id,
|
||||
}
|
||||
if enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
"assist_info_for_combine":
|
||||
assist_info_for_combine,
|
||||
})
|
||||
else:
|
||||
stage3_kwargs.update({
|
||||
"expand_idx": assist_info_for_combine,
|
||||
})
|
||||
if need_extra_args:
|
||||
stage3_kwargs.update({
|
||||
"tp_send_counts": tp_recv_counts,
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if a3_need_extra_args and enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
})
|
||||
kwargs_mc2.update(stage3_kwargs)
|
||||
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
|
||||
**kwargs_mc2
|
||||
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
|
||||
**kwargs_mc2)
|
||||
|
||||
if shared_experts is None:
|
||||
return hidden_states
|
||||
else:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(shared_act[0], down_out_list)
|
||||
shared_output, _ = shared_experts.down_proj(shared_act)
|
||||
npu_wait_tensor(shared_act, down_out_list)
|
||||
shared_output, _ = shared_experts.down_proj(
|
||||
(shared_act, swiglu_out_scale))
|
||||
return hidden_states, shared_output
|
||||
|
||||
|
||||
@@ -640,7 +674,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
try:
|
||||
device_group = self.ep_group.device_group
|
||||
device_group = get_mc2_group().device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
@@ -755,8 +789,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
|
||||
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
|
||||
is_prefill, is_deepseek_v3_r1)
|
||||
fused_moe_state = get_forward_context().fused_moe_state
|
||||
if fused_moe_state == FusedMoEState.AllGatherEP:
|
||||
return fused_experts_with_allgather(
|
||||
hidden_states=x,
|
||||
@@ -782,7 +815,9 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts)
|
||||
shared_experts=shared_experts,
|
||||
is_torchair=self.torchair_graph_enabled,
|
||||
mc2_mask=kwargs.get("mc2_mask", None))
|
||||
elif fused_moe_state in [
|
||||
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
||||
]:
|
||||
|
||||
Reference in New Issue
Block a user