From 23ca68d0c8557a91b7213782de5c7fafa0cfb985 Mon Sep 17 00:00:00 2001 From: zzzzwwjj <34335947+zzzzwwjj@users.noreply.github.com> Date: Tue, 17 Jun 2025 17:49:03 +0800 Subject: [PATCH] [refactor] Refactoring AscendFusedMoE (#1229) ### What this PR does / why we need it? This PR resolves [issue 1147](https://github.com/vllm-project/vllm-ascend/issues/1147) 1. Move fused_moe code into one file `fused_moe.py`. 2. Integrate branch conditions into function `get_fused_moe_state`. ### Does this PR introduce _any_ user-facing change? 1. This PR has removed the env `VLLM_ENABLE_MC2`, because this env is unnecessary: we can make the judgment based on the current scenario without it, and keeping it only adds complexity. 2. This PR has removed the env `USING_LCCL_COM`, because this env is already obsolete. 3. `additional_config.expert_tensor_parallel_size` is already obsolete, and now we also use the parameter `enable_expert_parallel`, consistent with vLLM. ### How was this patch tested? Signed-off-by: zzzzwwjj <1183291235@qq.com> --- vllm_ascend/attention/mla_v1.py | 3 + vllm_ascend/envs.py | 8 -- vllm_ascend/models/deepseek_dbo.py | 45 +------ vllm_ascend/models/deepseek_v2.py | 59 ++------- vllm_ascend/ops/fused_moe.py | 146 +++++++++++++---------- vllm_ascend/quantization/w8a8_dynamic.py | 12 +- vllm_ascend/utils.py | 19 +++ vllm_ascend/worker/model_runner_v1.py | 47 +++----- vllm_ascend/worker/worker_v1.py | 15 ++- 9 files changed, 150 insertions(+), 204 deletions(-) diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py index f741508..189aa38 100644 --- a/vllm_ascend/attention/mla_v1.py +++ b/vllm_ascend/attention/mla_v1.py @@ -136,6 +136,7 @@ class AscendMLAMetadata: # For logging. num_input_tokens: int = 0 # Number of tokens including padding. 
+ max_num_tokens_across_dp: int = 0 with_prefill_across_dp: bool = False query_lens: Optional[list[int]] = None @@ -364,6 +365,7 @@ class AscendMLAMetadataBuilder: common_attn_metadata: CommonAttentionMetadata, common_prefix_len: Optional[int] = None, graph_pad_size: int = -1, + max_num_tokens_across_dp: int = 0, with_prefill_across_dp: bool = False, ) -> AscendMLAMetadata: assert self._num_decodes + self._num_prefills == num_reqs @@ -509,6 +511,7 @@ class AscendMLAMetadataBuilder: query_start_loc=query_start_loc, block_tables=block_table, seq_lens=seq_lens, + max_num_tokens_across_dp=max_num_tokens_across_dp, with_prefill_across_dp=with_prefill_across_dp, ) diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py index 8d7c10e..74e9c19 100644 --- a/vllm_ascend/envs.py +++ b/vllm_ascend/envs.py @@ -50,18 +50,10 @@ env_variables: Dict[str, Callable[[], Any]] = { # value is None, which means the system default C compiler will be used. "C_COMPILER": lambda: os.getenv("C_COMPILER", None), - # Whether to enable MC2 for DeepSeek. If not set, the default value is False. - # MC2 is a fusion operator provided by Ascend to speed up computing and communication. - # Find more detail here: https://www.hiascend.com/document/detail/zh/canncommercial/81RC1/developmentguide/opdevg/ascendcbestP/atlas_ascendc_best_practices_10_0043.html - "VLLM_ENABLE_MC2": - lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))), # Whether to enable the topk optimization. It's disabled by default for experimental support # We'll make it enabled by default in the future. "VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE", '0'))), - # Whether to use LCCL communication. If not set, the default value is False. - "USING_LCCL_COM": - lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))), # The version of the Ascend chip. If not set, the default value is # ASCEND910B1. It's used for package building. Please make sure that the # version is correct. 
diff --git a/vllm_ascend/models/deepseek_dbo.py b/vllm_ascend/models/deepseek_dbo.py index 9db49cb..6ab0837 100644 --- a/vllm_ascend/models/deepseek_dbo.py +++ b/vllm_ascend/models/deepseek_dbo.py @@ -51,9 +51,9 @@ from vllm.model_executor.layers.sampler import get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.models.deepseek_v2 import \ - DeepseekV2ForCausalLM # ruff: noqa: E501 + DeepseekV2ForCausalLM # noqa: E501 from vllm.model_executor.models.deepseek_v2 import \ - yarn_get_mscale # ruff: noqa: E501 + yarn_get_mscale # noqa: E501 from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention) @@ -79,7 +79,6 @@ from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.utils import dispose_tensor VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO -VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 class CustomDeepseekDBOMLP(CustomDeepseekV2MLP): @@ -189,26 +188,8 @@ class CustomDeepseekDBOMoE(nn.Module): if hasattr(attn_metadata, 'with_prefill_across_dp'): is_prefill = is_prefill or attn_metadata.with_prefill_across_dp - num_tokens, hidden_size = hidden_states.shape - old_hidden_states = hidden_states.clone() - if self.tp_size > 1: - if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: - chunks = torch.chunk(hidden_states, self.tp_size, dim=0) - hidden_states = chunks[self.tp_rank] - elif not self.torchair_graph_enabled: - num_padding_tokens = (self.tp_size - - num_tokens % self.tp_size) % self.tp_size - # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C - if num_padding_tokens > 0: - hidden_states = nn.functional.pad( - hidden_states, (0, 0, 0, num_padding_tokens)) - chunk_hidden_states = torch.tensor_split(hidden_states, - self.tp_size, - dim=0) - hidden_states = chunk_hidden_states[self.tp_rank] - # router_logits: (num_tokens, n_experts) 
router_logits, _ = self.gate(hidden_states) @@ -220,33 +201,13 @@ class CustomDeepseekDBOMoE(nn.Module): enable_force_load_balance=enable_force_load_balance, ) * self.routed_scaling_factor - if self.tp_size > 1: - if self.torchair_graph_enabled: - if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill: - final_hidden_states = torch.zeros( - [num_tokens, hidden_size], - dtype=self.params_dtype, - device="npu") - dist.all_gather_into_tensor(final_hidden_states, - hidden_states, self.tp_group) - hidden_states = final_hidden_states - else: - hidden_states = tensor_model_parallel_all_reduce( - hidden_states) - else: - dist.all_gather(list(chunk_hidden_states), hidden_states, - self.tp_group) - hidden_states = torch.cat(chunk_hidden_states, dim=0) - if num_padding_tokens > 0: - hidden_states = hidden_states[:-num_padding_tokens] - if self.n_shared_experts is not None: shared_output = self.shared_experts(old_hidden_states) if shared_output is not None: hidden_states = hidden_states + shared_output - return hidden_states.view(num_tokens, hidden_size) + return hidden_states # ----------------------------------------- TBO-related -------------------------------------------- def _forward_ms_op_shared_expert( diff --git a/vllm_ascend/models/deepseek_v2.py b/vllm_ascend/models/deepseek_v2.py index 0ae1142..e96b2e9 100644 --- a/vllm_ascend/models/deepseek_v2.py +++ b/vllm_ascend/models/deepseek_v2.py @@ -28,7 +28,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union import torch -import torch.distributed as dist import torch_npu import vllm.envs as envs from torch import nn @@ -37,7 +36,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.distributed import (get_pp_group, get_tensor_model_parallel_world_size, - get_tp_group, tensor_model_parallel_all_reduce) + get_tp_group) from vllm.distributed.parallel_state import get_dp_group from vllm.forward_context import get_forward_context from 
vllm.model_executor.layers.activation import SiluAndMul @@ -54,9 +53,9 @@ from vllm.model_executor.layers.sampler import get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.models.deepseek_v2 import \ - DeepseekV2ForCausalLM # ruff: noqa: E501 + DeepseekV2ForCausalLM # noqa: E501 from vllm.model_executor.models.deepseek_v2 import \ - yarn_get_mscale # ruff: noqa: E501 + yarn_get_mscale # noqa: E501 from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention, DeepseekV2DecoderLayer, DeepseekV2MLAAttention) @@ -65,7 +64,6 @@ from vllm.model_executor.models.utils import ( maybe_prefix) from vllm.sequence import IntermediateTensors -import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import AscendFusedMoE @@ -74,8 +72,6 @@ from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod from vllm_ascend.utils import (dispose_tensor, npu_stream_switch, npu_wait_tensor) -VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 - class CustomDeepseekV2SiluAndMul(SiluAndMul): @@ -240,9 +236,8 @@ class CustomDeepseekV2MoE(nn.Module): ascend_config = get_ascend_config() self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled - # NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on self.enable_multistream_moe = \ - ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2 + ascend_config.torchair_graph_config.enable_multistream_moe self.gate = ReplicatedLinear(config.hidden_size, config.n_routed_experts, @@ -312,22 +307,6 @@ class CustomDeepseekV2MoE(nn.Module): enable_force_load_balance = False if hasattr(attn_metadata, 'with_prefill_across_dp'): is_prefill = is_prefill or attn_metadata.with_prefill_across_dp - num_tokens, hidden_size = hidden_states.shape - old_hidden_states = 
hidden_states - use_separated_shared_experts = (self.shared_experts is not None - and not self.enable_multistream_moe) - - if self.tp_size > 1: - if (VLLM_ENABLE_MC2 - and not is_prefill) or not (self.torchair_graph_enabled or - self.ep_group.world_size == 1): - if num_tokens < self.tp_size: - hidden_states = nn.functional.pad( - hidden_states, (0, 0, 0, self.tp_size - num_tokens)) - chunk_hidden_states = torch.tensor_split(hidden_states, - self.tp_size, - dim=0) - hidden_states = chunk_hidden_states[self.tp_rank] # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) @@ -338,34 +317,14 @@ class CustomDeepseekV2MoE(nn.Module): is_prefill=is_prefill, top_k=CustomDeepseekV2MoE.top_k, enable_force_load_balance=enable_force_load_balance, - shared_experts=(self.shared_experts - if not use_separated_shared_experts else None), + shared_experts=self.shared_experts, ) - if not isinstance(experts_hidden_states, tuple): - hidden_states = experts_hidden_states * self.routed_scaling_factor - else: - hidden_states = ( - experts_hidden_states[0] * self.routed_scaling_factor + - experts_hidden_states[1]) + hidden_states = ( + experts_hidden_states[0] * self.routed_scaling_factor + + experts_hidden_states[1]) - if self.tp_size > 1: - if (VLLM_ENABLE_MC2 - and not is_prefill) or not (self.torchair_graph_enabled or - self.ep_group.world_size == 1): - dist.all_gather(list(chunk_hidden_states), hidden_states, - self.tp_group) - hidden_states = torch.cat(chunk_hidden_states, dim=0) - if num_tokens < self.tp_size: - hidden_states = hidden_states[:num_tokens] - else: - hidden_states = tensor_model_parallel_all_reduce(hidden_states) - - if use_separated_shared_experts: - hidden_states = hidden_states + self.shared_experts( - old_hidden_states) - - return hidden_states.view(num_tokens, hidden_size) + return hidden_states class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): diff --git a/vllm_ascend/ops/fused_moe.py b/vllm_ascend/ops/fused_moe.py index 
d6115d3..4a4b488 100644 --- a/vllm_ascend/ops/fused_moe.py +++ b/vllm_ascend/ops/fused_moe.py @@ -21,11 +21,13 @@ from typing import Any, Callable, List, Optional, Tuple, Union import torch import torch.distributed as dist import torch_npu +from torch import nn from vllm.config import get_current_vllm_config -from vllm.distributed import (GroupCoordinator, +from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) -from vllm.distributed.parallel_state import get_dp_group +from vllm.distributed.parallel_state import get_dp_group, get_tp_group +from vllm.forward_context import get_forward_context from vllm.model_executor.layers.fused_moe.layer import ( FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod, determine_expert_map) @@ -36,10 +38,10 @@ import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer -from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor +from vllm_ascend.utils import (FusedMoEState, dispose_tensor, + get_fused_moe_state, npu_stream_switch, + npu_wait_tensor) -VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 -USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER @@ -845,8 +847,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): super().__init__(moe=moe) vllm_config = get_current_vllm_config() - ep_group = get_ep_group() - self.ep_size = ep_group.world_size + self.ep_group = get_ep_group() + self.ep_size = self.ep_group.world_size self.global_batch_size = vllm_config.scheduler_config.max_num_seqs self.local_batch_size = self.global_batch_size // self.ep_size self.max_model_len = vllm_config.model_config.max_model_len @@ -855,7 +857,7 @@ class 
AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled try: - device_group = ep_group.device_group + device_group = self.ep_group.device_group # TODO: Try local_rank = ep_group.rank_in_group local_rank = torch.distributed.get_rank(group=device_group) backend = device_group._get_backend(torch.device("npu")) @@ -931,7 +933,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): if enable_force_load_balance: topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) - if VLLM_ENABLE_MC2 and not is_prefill: + fused_moe_state = get_fused_moe_state(self.ep_group.world_size, + is_prefill) + if fused_moe_state == FusedMoEState.MC2: return fused_experts_with_mc2( hidden_states=x, w1=layer.w13_weight, @@ -942,7 +946,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): expert_map=expert_map, moe_all_to_all_group_name=self.moe_all_to_all_group_name, shared_experts=shared_experts) - elif self.torchair_graph_enabled or get_ep_group().world_size == 1: + elif fused_moe_state == FusedMoEState.AllGather: return fused_experts(hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, @@ -1022,9 +1026,6 @@ class AscendFusedMoE(FusedMoE): get_dp_group().world_size), vllm_parallel_config=vllm_config.parallel_config)) - self.moe_parallel_config.ep_size = get_ep_group().world_size - self.moe_parallel_config.tp_size = get_etp_group().world_size - self.top_k = top_k self.num_experts = num_experts self.global_num_experts = num_experts @@ -1066,10 +1067,9 @@ class AscendFusedMoE(FusedMoE): self.ep_size, get_ep_group().rank_in_group, self.global_num_experts) - self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group - self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group - self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled + self.enable_multistream_moe = \ + ascend_config.torchair_graph_config.enable_multistream_moe if self.scoring_func != 
"softmax" and not self.use_grouped_topk: raise ValueError("Only softmax scoring function is supported for " @@ -1109,6 +1109,8 @@ class AscendFusedMoE(FusedMoE): moe_quant_params["intermediate_size_full"] = intermediate_size self.ep_group = get_ep_group() + # NOTE: self.tp_group is not expert_tp_group + self.tp_group = get_tp_group().device_group self.quant_method.create_weights(layer=self, **moe_quant_params) def forward(self, @@ -1125,25 +1127,45 @@ class AscendFusedMoE(FusedMoE): else: real_top_k = self.top_k - # MC2 ag/rs broadcast/all_reduce - # prefill_req x x √ - # decode_req √ x √ - # graph_mode √ √ x - if self.dp_size > 1: - if VLLM_ENABLE_MC2 and not is_prefill: - ... - elif self.torchair_graph_enabled: - if USING_LCCL_COM: # type: ignore - hidden_states = get_dp_group().all_gather( - hidden_states, 0, False) - router_logits = get_dp_group().all_gather( - router_logits, 0, False) - elif self.torchair_graph_enabled and not is_prefill: - hidden_states = get_dp_group().all_gather(hidden_states, 0) - router_logits = get_dp_group().all_gather(router_logits, 0) - else: - hidden_states, router_logits = get_ep_group().dispatch( - hidden_states, router_logits) + num_tokens, hidden_size = hidden_states.shape + + fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size, + is_prefill) + if shared_experts: + if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2: + shared_hidden_states = shared_experts(hidden_states) + + tp_size = get_tensor_model_parallel_world_size() + if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: + if num_tokens < tp_size: + hidden_states = nn.functional.pad( + hidden_states, (0, 0, 0, tp_size - num_tokens)) + router_logits = nn.functional.pad( + router_logits, (0, 0, 0, tp_size - num_tokens)) + chunk_hidden_states = torch.tensor_split(hidden_states, + tp_size, + dim=0) + chunk_router_logits = torch.tensor_split(router_logits, + tp_size, + dim=0) + tp_rank = get_tensor_model_parallel_rank() + 
hidden_states = chunk_hidden_states[tp_rank] + router_logits = chunk_router_logits[tp_rank] + if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: + # NOTE: When in torchair graph, it has been padded in model_runner_v1 + if not self.torchair_graph_enabled or is_prefill: + attn_metadata = get_forward_context().attn_metadata + if attn_metadata is not None: + max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp + if num_tokens < max_num_tokens_across_dp: + hidden_states = nn.functional.pad( + hidden_states, + (0, 0, 0, max_num_tokens_across_dp - num_tokens)) + router_logits = nn.functional.pad( + router_logits, + (0, 0, 0, max_num_tokens_across_dp - num_tokens)) + hidden_states = get_dp_group().all_gather(hidden_states, 0) + router_logits = get_dp_group().all_gather(router_logits, 0) # Matrix multiply. e_hidden_states = self.quant_method.apply( @@ -1167,36 +1189,36 @@ class AscendFusedMoE(FusedMoE): shared_experts=shared_experts, ) - if shared_experts is not None: - # Provide dummy implementation of "non-separated" shared experts. - if not isinstance(e_hidden_states, tuple): - return e_hidden_states, shared_experts(hidden_states) - else: - return e_hidden_states + if shared_experts: + if isinstance(e_hidden_states, tuple): + e_hidden_states, shared_hidden_states = e_hidden_states - if self.dp_size > 1: - if VLLM_ENABLE_MC2 and not is_prefill: - ... 
- elif self.torchair_graph_enabled: - if USING_LCCL_COM: # type: ignore - e_hidden_states = dist._functional_collectives.reduce_scatter_tensor( - e_hidden_states, - "sum", - scatter_dim=0, - group=get_dp_group().device_group) - elif self.torchair_graph_enabled and not is_prefill: - e_hidden_states = dist._functional_collectives.reduce_scatter_tensor( - e_hidden_states, - "sum", - scatter_dim=0, - group=get_dp_group().device_group) - else: - e_hidden_states = get_ep_group().combine(e_hidden_states) + if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather: + dist.all_gather(list(chunk_hidden_states), e_hidden_states, + self.tp_group) + final_hidden_states = torch.cat(chunk_hidden_states, dim=0) + if num_tokens < tp_size: + final_hidden_states = final_hidden_states[:num_tokens] + dispose_tensor(e_hidden_states) + elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather: + final_hidden_states = dist._functional_collectives.reduce_scatter_tensor( + e_hidden_states, + "sum", + scatter_dim=0, + group=get_dp_group().device_group) + final_hidden_states = final_hidden_states[:num_tokens] + dispose_tensor(e_hidden_states) + else: + final_hidden_states = e_hidden_states - if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): - e_hidden_states = tensor_model_parallel_all_reduce(e_hidden_states) + if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather: + final_hidden_states = tensor_model_parallel_all_reduce( + final_hidden_states) - return e_hidden_states + if shared_experts: + return final_hidden_states, shared_hidden_states + else: + return final_hidden_states # ----------------------------------------- TBO-related -------------------------------------------- diff --git a/vllm_ascend/quantization/w8a8_dynamic.py b/vllm_ascend/quantization/w8a8_dynamic.py index 66a0a30..6c44a6a 100644 --- a/vllm_ascend/quantization/w8a8_dynamic.py +++ b/vllm_ascend/quantization/w8a8_dynamic.py @@ -22,15 +22,13 @@ import torch.distributed as dist import 
torch_npu from vllm.distributed import GroupCoordinator -import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.ops.fused_moe import select_experts -from vllm_ascend.utils import (dispose_tensor, npu_stream_switch, +from vllm_ascend.utils import (FusedMoEState, dispose_tensor, + get_fused_moe_state, npu_stream_switch, npu_wait_tensor) -VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2 - def apply_mlp(hidden_states: torch.Tensor, w1: torch.Tensor, @@ -660,7 +658,9 @@ class AscendW8A8DynamicFusedMoEMethod: topk_weights = topk_weights.to(x.dtype) - if VLLM_ENABLE_MC2 and not is_prefill: + fused_moe_state = get_fused_moe_state(self.ep_group.world_size, + is_prefill) + if fused_moe_state == FusedMoEState.MC2: return fused_experts_with_mc2( hidden_states=x, w1=layer.w13_weight, @@ -675,7 +675,7 @@ class AscendW8A8DynamicFusedMoEMethod: log2phy=log2phy, global_redundant_expert_num=global_redundant_expert_num, shared_experts=shared_experts) - elif self.torchair_graph_enabled or self.ep_group.world_size == 1: + elif fused_moe_state == FusedMoEState.AllGather: return fused_experts(hidden_states=x, w1=layer.w13_weight, w1_scale=layer.w13_weight_scale, diff --git a/vllm_ascend/utils.py b/vllm_ascend/utils.py index d932053..eeab287 100644 --- a/vllm_ascend/utils.py +++ b/vllm_ascend/utils.py @@ -20,6 +20,7 @@ import atexit import math from contextlib import contextmanager, nullcontext +from enum import Enum from threading import Lock from typing import TYPE_CHECKING, List, Tuple @@ -275,3 +276,21 @@ def npu_wait_tensor(self: torch.Tensor, *, enabled: bool = True): return _npu_wait_tensor(self, dependency) if enabled else self + + +# TODO(zzzzwwjj): move this into forward_context +class FusedMoEState(Enum): + AllGather = 0 + All2All = 1 + MC2 = 2 + + +# TODO(zzzzwwjj): add soc_version to choose branch +def get_fused_moe_state(ep_size: int, with_prefill: 
bool): + if ep_size == 1: + return FusedMoEState.AllGather + # NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph. + elif ep_size < 16 or with_prefill: + return FusedMoEState.All2All + else: + return FusedMoEState.MC2 diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 6d226da..801819b 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -348,15 +348,10 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.init_torchair_graph_batch_sizes() if len(self.torchair_graph_batch_sizes) == 0: - #If MC2 is enabled, torchair_graph_batch_size should pad to tp_size - if envs_ascend.VLLM_ENABLE_MC2: - self.torchair_graph_batch_sizes = [ - self.scheduler_config.max_num_seqs - ] - else: - self.torchair_graph_batch_sizes = [ - 1, self.scheduler_config.max_num_seqs - ] + # TODO(zzzzwwjj): check torchair_graph_batch_sizes init code + self.torchair_graph_batch_sizes = [ + self.scheduler_config.max_num_seqs + ] torch._dynamo.cache_size.config.cache_size_limit += len( self.torchair_graph_batch_sizes) @@ -569,10 +564,12 @@ class NPUModelRunner(LoRAModelRunnerMixin): self.input_batch.refresh_sampling_metadata() def _get_forward_metadata_across_dp( - self, batch_size: int, with_prefill: bool) -> tuple[int, bool]: - forward_metadata = torch.tensor([batch_size, with_prefill], - device="cpu", - dtype=torch.int32) + self, total_num_scheduled_tokens: int, + with_prefill: bool) -> tuple[int, bool]: + forward_metadata = torch.tensor( + [total_num_scheduled_tokens, with_prefill], + device="cpu", + dtype=torch.int32) dist.all_reduce(forward_metadata, op=ReduceOp.MAX, group=get_dp_group().cpu_group) @@ -901,11 +898,11 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.dp_size > 1: max_num_tokens, with_prefill = self._get_forward_metadata_across_dp( total_num_scheduled_tokens, with_prefill) + extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens 
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill # Add graph_pad_size here - if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled - and not with_prefill): + if self.torchair_graph_enabled and not with_prefill: if self.dp_size > 1: padded_batch_size = self.select_torchair_padded_batch_size( max_num_tokens) @@ -984,8 +981,7 @@ class NPUModelRunner(LoRAModelRunnerMixin): else: positions = self.positions[:num_input_tokens] - if (envs_ascend.VLLM_ENABLE_MC2 - or self.torchair_graph_enabled) and not with_prefill: + if self.torchair_graph_enabled and not with_prefill: input_ids = self.input_ids[:padded_batch_size] positions = self.positions[:padded_batch_size] @@ -1885,20 +1881,15 @@ class NPUModelRunner(LoRAModelRunnerMixin): return spec_token_ids def init_torchair_graph_batch_sizes(self): + start_graph_batch_size = 4 tp_size = get_tensor_model_parallel_world_size() - batch_size_step = 8 - largest_batch_size = 1 - if envs_ascend.VLLM_ENABLE_MC2: - batch_size_step = max(batch_size_step, tp_size) - largest_batch_size = batch_size_step - while (largest_batch_size < 8): - self.torchair_graph_batch_sizes.append(largest_batch_size) - largest_batch_size *= 2 + # NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks + start_graph_batch_size = max(start_graph_batch_size, tp_size) - while (largest_batch_size <= self.scheduler_config.max_num_seqs): - self.torchair_graph_batch_sizes.append(largest_batch_size) - largest_batch_size += batch_size_step + while (start_graph_batch_size <= self.scheduler_config.max_num_seqs): + self.torchair_graph_batch_sizes.append(start_graph_batch_size) + start_graph_batch_size *= 2 def select_torchair_padded_batch_size(self, batch_size: int): selected_batch_size = self.max_num_reqs diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 76844c9..6fe84a4 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -38,7 +38,6 @@ from 
vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.worker.worker_base import WorkerBase -import vllm_ascend.envs as envs_ascend from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.device_allocator.camem import CaMemAllocator from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel @@ -247,15 +246,15 @@ class NPUWorker(WorkerBase): def execute_dummy_batch(self) -> None: runner = self.model_runner - num_tokens = 1 + max_num_tokens = 1 + with_prefill = False if runner.dp_size > 1: max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp( - 1, False) - if envs_ascend.VLLM_ENABLE_MC2 or runner.torchair_graph_enabled: - if not with_prefill: - num_tokens = max_num_tokens - num_tokens = runner.select_torchair_padded_batch_size(num_tokens) - runner._dummy_run(num_tokens, + max_num_tokens, with_prefill) + if runner.torchair_graph_enabled and not with_prefill: + max_num_tokens = runner.select_torchair_padded_batch_size( + max_num_tokens) + runner._dummy_run(max_num_tokens, is_compile=False, with_prefill=with_prefill)