[main][refactor] Refactoring forward_context and model_runner_v1 (#1979)
### What this PR does / why we need it?
A refactoring of forward_context and model_runner_v1, add some context
which is necessary in model inference into forward_context, and refactor
dummy_run logic, make it more reasonable.
Some details for this PR:
Add `ascend_forward_context`;
Update mc2_v2 op, and support `active_mask` param;
Update scripts in examples dir;
refactor `dummy_run` logic;
Add soc_version for A2 and A3;
### Does this PR introduce _any_ user-facing change?
No change at user-facing.
### How was this patch tested?
- vLLM version: v0.10.0
- vLLM main:
57c22e57f9
Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -40,12 +40,15 @@ from vllm.model_executor.layers.quantization.base_config import \
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||
from vllm_ascend.distributed.communication_op import \
|
||||
data_parallel_reduce_scatter
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
|
||||
get_all_reduce_merge_state, get_fused_moe_state,
|
||||
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
|
||||
get_all_reduce_merge_state,
|
||||
get_ascend_soc_version,
|
||||
get_rm_router_logits_state, is_310p)
|
||||
|
||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||
@@ -127,9 +130,23 @@ def fused_experts_with_mc2(
|
||||
moe_parallel_config: FusedMoEParallelConfig,
|
||||
expert_map: torch.Tensor = None,
|
||||
moe_all_to_all_group_name: Optional[str] = None,
|
||||
shared_experts: Optional[Any] = None
|
||||
shared_experts: Optional[Any] = None,
|
||||
is_torchair: bool = False,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
global_bs = 0
|
||||
quant_mode = 0
|
||||
ep_rank_id = moe_parallel_config.ep_rank
|
||||
ep_world_size = moe_parallel_config.ep_size
|
||||
|
||||
# NOTE: Currently, when in A3 or in torchair graph, we need to pass in some extra param into dispatch & combine
|
||||
need_extra_args = (get_ascend_soc_version() == AscendSocVersion.A3
|
||||
or is_torchair)
|
||||
|
||||
# NOTE: Currently, when in A3, we need to pass in some extra param into dispatch & combine
|
||||
a3_need_extra_args = get_ascend_soc_version() == AscendSocVersion.A3
|
||||
|
||||
enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
|
||||
|
||||
moe_expert_num = len(expert_map)
|
||||
kwargs_mc2 = {
|
||||
"x": hidden_states,
|
||||
@@ -137,32 +154,35 @@ def fused_experts_with_mc2(
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
"moe_expert_num": moe_expert_num,
|
||||
"global_bs": global_bs,
|
||||
"global_bs": 0,
|
||||
}
|
||||
|
||||
rank = torch.distributed.get_rank()
|
||||
|
||||
quant_mode = 0
|
||||
ep_rank_id = moe_parallel_config.ep_rank
|
||||
ep_world_size = moe_parallel_config.ep_size
|
||||
|
||||
tp_world_size = moe_parallel_config.tp_size
|
||||
tp_rank = rank % tp_world_size
|
||||
|
||||
stage1_kwargs = {
|
||||
"scales": None,
|
||||
"quant_mode": quant_mode,
|
||||
"group_ep": moe_all_to_all_group_name,
|
||||
"ep_world_size": ep_world_size,
|
||||
"ep_rank_id": ep_rank_id,
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": tp_world_size,
|
||||
"tp_rank_id": tp_rank,
|
||||
}
|
||||
if need_extra_args:
|
||||
stage1_kwargs.update({
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if a3_need_extra_args and enable_dispatch_v2:
|
||||
stage1_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
})
|
||||
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
|
||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
||||
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
||||
output = torch_npu.npu_moe_distribute_dispatch_v2(
|
||||
**kwargs_mc2
|
||||
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||
**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, ep_recv_counts = output[
|
||||
0:5]
|
||||
|
||||
if shared_experts is not None:
|
||||
@@ -205,7 +225,6 @@ def fused_experts_with_mc2(
|
||||
kwargs_mc2 = {
|
||||
"expand_x": down_out_list,
|
||||
"expert_ids": topk_ids,
|
||||
"expand_idx": expand_idx,
|
||||
"expert_scales": topk_weights.to(torch.float32),
|
||||
"expert_shard_type": 0,
|
||||
"shared_expert_rank_num": 0,
|
||||
@@ -218,15 +237,33 @@ def fused_experts_with_mc2(
|
||||
"group_ep": moe_all_to_all_group_name,
|
||||
"ep_world_size": ep_world_size,
|
||||
"ep_rank_id": ep_rank_id,
|
||||
"tp_send_counts": tp_recv_counts,
|
||||
# "group_tp": self.moe_rs_group_name,
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": tp_world_size,
|
||||
"tp_rank_id": tp_rank,
|
||||
}
|
||||
if enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
"assist_info_for_combine":
|
||||
assist_info_for_combine,
|
||||
})
|
||||
else:
|
||||
stage3_kwargs.update({
|
||||
"expand_idx": assist_info_for_combine,
|
||||
})
|
||||
if need_extra_args:
|
||||
stage3_kwargs.update({
|
||||
"tp_send_counts": tp_recv_counts,
|
||||
"group_tp": moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
if a3_need_extra_args and enable_dispatch_v2:
|
||||
stage3_kwargs.update({
|
||||
"x_active_mask": mc2_mask,
|
||||
})
|
||||
kwargs_mc2.update(stage3_kwargs)
|
||||
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine_v2(
|
||||
**kwargs_mc2
|
||||
) if enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(
|
||||
**kwargs_mc2)
|
||||
|
||||
if shared_experts is None:
|
||||
return hidden_states
|
||||
@@ -981,17 +1018,14 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
super().__init__(moe=moe)
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
self.ep_group = get_ep_group()
|
||||
self.ep_size = self.moe.moe_parallel_config.ep_size
|
||||
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.local_batch_size = self.global_batch_size // self.ep_size
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
try:
|
||||
device_group = self.ep_group.device_group
|
||||
device_group = get_mc2_group().device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
@@ -1074,8 +1108,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
fused_moe_state = get_fused_moe_state(self.ep_size, is_prefill,
|
||||
is_deepseek_v3_r1)
|
||||
fused_moe_state = get_forward_context().fused_moe_state
|
||||
|
||||
if fused_moe_state == FusedMoEState.MC2:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
@@ -1087,7 +1121,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
shared_experts=shared_experts)
|
||||
shared_experts=shared_experts,
|
||||
mc2_mask=kwargs.get("mc2_mask", None))
|
||||
elif fused_moe_state in [
|
||||
FusedMoEState.AllGather, FusedMoEState.NaiveMulticast
|
||||
]:
|
||||
@@ -1295,52 +1330,56 @@ class AscendFusedMoE(FusedMoE):
|
||||
real_top_k = self.top_k
|
||||
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
is_deepseek_v3_r1 = self.global_num_experts == 256
|
||||
|
||||
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
|
||||
is_prefill, is_deepseek_v3_r1)
|
||||
forward_context = get_forward_context()
|
||||
fused_moe_state = forward_context.fused_moe_state
|
||||
mc2_mask = forward_context.mc2_mask
|
||||
if shared_experts:
|
||||
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
|
||||
# When all_reduce_merge is in progress, shared_experts does not do all_reduce in mlp, but waits until shared_experts+router_experts are completed before doing all_reduce
|
||||
shared_hidden_states = shared_experts(hidden_states)
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if (tp_size > 1 and fused_moe_state not in [
|
||||
if (fused_moe_state not in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
] and not replace_allreduce):
|
||||
if num_tokens < tp_size:
|
||||
if fused_moe_state in {FusedMoEState.MC2}:
|
||||
padding_size = forward_context.padded_num_tokens
|
||||
else:
|
||||
# TODO: Determine if we can remove the padding
|
||||
padding_size = tp_size
|
||||
if num_tokens < padding_size:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, tp_size - num_tokens))
|
||||
hidden_states, (0, 0, 0, padding_size - num_tokens))
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits, (0, 0, 0, tp_size - num_tokens))
|
||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||
tp_size,
|
||||
dim=0)
|
||||
chunk_router_logits = torch.tensor_split(router_logits,
|
||||
tp_size,
|
||||
dim=0)
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
hidden_states = chunk_hidden_states[tp_rank]
|
||||
router_logits = chunk_router_logits[tp_rank]
|
||||
router_logits, (0, 0, 0, padding_size - num_tokens))
|
||||
if tp_size > 1:
|
||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||
tp_size,
|
||||
dim=0)
|
||||
chunk_router_logits = torch.tensor_split(router_logits,
|
||||
tp_size,
|
||||
dim=0)
|
||||
chunk_mc2_mask = torch.tensor_split(mc2_mask, tp_size, dim=0)
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
hidden_states = chunk_hidden_states[tp_rank]
|
||||
router_logits = chunk_router_logits[tp_rank]
|
||||
mc2_mask = chunk_mc2_mask[tp_rank]
|
||||
|
||||
if self.dp_size > 1:
|
||||
if fused_moe_state == FusedMoEState.AllGather:
|
||||
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
||||
if not self.torchair_graph_enabled:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata is not None:
|
||||
max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
|
||||
if num_tokens < max_num_tokens_across_dp:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states,
|
||||
(0, 0, 0,
|
||||
max_num_tokens_across_dp - num_tokens))
|
||||
if not self.rm_router_logits:
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits,
|
||||
(0, 0, 0,
|
||||
max_num_tokens_across_dp - num_tokens))
|
||||
max_tokens_across_dp = forward_context.max_tokens_across_dp
|
||||
if num_tokens < max_tokens_across_dp:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states,
|
||||
(0, 0, 0, max_tokens_across_dp - num_tokens))
|
||||
if not self.rm_router_logits:
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits,
|
||||
(0, 0, 0, max_tokens_across_dp - num_tokens))
|
||||
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||
if self.rm_router_logits:
|
||||
router_logits, _ = gate(hidden_states)
|
||||
@@ -1379,20 +1418,24 @@ class AscendFusedMoE(FusedMoE):
|
||||
global_redundant_expert_num=self.global_redundant_expert_num,
|
||||
shared_experts=shared_experts if self.torchair_graph_enabled
|
||||
and self.enable_multistream_moe and not is_prefill else None,
|
||||
mc2_mask=mc2_mask,
|
||||
)
|
||||
|
||||
if shared_experts:
|
||||
if isinstance(e_hidden_states, tuple):
|
||||
e_hidden_states, shared_hidden_states = e_hidden_states
|
||||
|
||||
if (tp_size > 1 and fused_moe_state not in [
|
||||
if (fused_moe_state not in [
|
||||
FusedMoEState.AllGather, FusedMoEState.AllGatherEP,
|
||||
FusedMoEState.NaiveMulticast
|
||||
] and not replace_allreduce):
|
||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
if num_tokens < tp_size:
|
||||
if tp_size > 1:
|
||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
if num_tokens < padding_size:
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
dispose_tensor(e_hidden_states)
|
||||
elif self.dp_size > 1:
|
||||
|
||||
Reference in New Issue
Block a user