[refactor] Refactoring AscendFusedMoE (#1229)

<!--  Thanks for sending a pull request!

BEFORE SUBMITTING, PLEASE READ
https://docs.vllm.ai/en/latest/contributing/overview.html

-->
### What this PR does / why we need it?
This PR resolves [issue
1147](https://github.com/vllm-project/vllm-ascend/issues/1147):
1. Move fused_moe code into one file `fused_moe.py`.
2. Integrate branch conditions into function `get_fused_moe_state`.
<!--
- Please clarify what changes you are proposing. The purpose of this
section is to outline the changes and how this PR fixes the issue.
If possible, please consider writing useful notes for better and faster
reviews in your PR.

- Please clarify why the changes are needed. For instance, the use case
and bug description.

- Fixes #
-->

### Does this PR introduce _any_ user-facing change?
1. This PR removes the env `VLLM_ENABLE_MC2`. The env is unnecessary: the
right code path can be chosen from the current scenario without it, and
keeping it only adds complexity.
2. This PR removes the env `USING_LCCL_COM`, because it is already
obsolete.
3. `additional_config.expert_tensor_parallel_size` is obsolete; the
`enable_expert_parallel` parameter is now used instead, consistent with
vLLM.
<!--
Note that it means *any* user-facing change including all aspects such
as API, interface or other behavior changes.
Documentation-only updates are not considered user-facing changes.
-->

### How was this patch tested?
<!--
CI passed with new added/existing test.
If it was tested in a way different from regular unit tests, please
clarify how you tested step by step, ideally copy and paste-able, so
that other reviewers can test and check, and descendants can verify in
the future.
If tests were not added, please describe why they were not added and/or
why it was difficult to add.
-->

Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
zzzzwwjj
2025-06-17 17:49:03 +08:00
committed by GitHub
parent 05dec7eda9
commit 23ca68d0c8
9 changed files with 150 additions and 204 deletions

View File

@@ -136,6 +136,7 @@ class AscendMLAMetadata:
# For logging. # For logging.
num_input_tokens: int = 0 # Number of tokens including padding. num_input_tokens: int = 0 # Number of tokens including padding.
max_num_tokens_across_dp: int = 0
with_prefill_across_dp: bool = False with_prefill_across_dp: bool = False
query_lens: Optional[list[int]] = None query_lens: Optional[list[int]] = None
@@ -364,6 +365,7 @@ class AscendMLAMetadataBuilder:
common_attn_metadata: CommonAttentionMetadata, common_attn_metadata: CommonAttentionMetadata,
common_prefix_len: Optional[int] = None, common_prefix_len: Optional[int] = None,
graph_pad_size: int = -1, graph_pad_size: int = -1,
max_num_tokens_across_dp: int = 0,
with_prefill_across_dp: bool = False, with_prefill_across_dp: bool = False,
) -> AscendMLAMetadata: ) -> AscendMLAMetadata:
assert self._num_decodes + self._num_prefills == num_reqs assert self._num_decodes + self._num_prefills == num_reqs
@@ -509,6 +511,7 @@ class AscendMLAMetadataBuilder:
query_start_loc=query_start_loc, query_start_loc=query_start_loc,
block_tables=block_table, block_tables=block_table,
seq_lens=seq_lens, seq_lens=seq_lens,
max_num_tokens_across_dp=max_num_tokens_across_dp,
with_prefill_across_dp=with_prefill_across_dp, with_prefill_across_dp=with_prefill_across_dp,
) )

View File

@@ -50,18 +50,10 @@ env_variables: Dict[str, Callable[[], Any]] = {
# value is None, which means the system default C compiler will be used. # value is None, which means the system default C compiler will be used.
"C_COMPILER": "C_COMPILER":
lambda: os.getenv("C_COMPILER", None), lambda: os.getenv("C_COMPILER", None),
# Whether to enable MC2 for DeepSeek. If not set, the default value is False.
# MC2 is a fusion operator provided by Ascend to speed up computing and communication.
# Find more detail here: https://www.hiascend.com/document/detail/zh/canncommercial/81RC1/developmentguide/opdevg/ascendcbestP/atlas_ascendc_best_practices_10_0043.html
"VLLM_ENABLE_MC2":
lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
# Whether to enable the topk optimization. It's disabled by default for experimental support # Whether to enable the topk optimization. It's disabled by default for experimental support
# We'll make it enabled by default in the future. # We'll make it enabled by default in the future.
"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE": "VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE":
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE", '0'))), lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE", '0'))),
# Whether to use LCCL communication. If not set, the default value is False.
"USING_LCCL_COM":
lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
# The version of the Ascend chip. If not set, the default value is # The version of the Ascend chip. If not set, the default value is
# ASCEND910B1. It's used for package building. Please make sure that the # ASCEND910B1. It's used for package building. Please make sure that the
# version is correct. # version is correct.

View File

@@ -51,9 +51,9 @@ from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_v2 import \ from vllm.model_executor.models.deepseek_v2 import \
DeepseekV2ForCausalLM # ruff: noqa: E501 DeepseekV2ForCausalLM # noqa: E501
from vllm.model_executor.models.deepseek_v2 import \ from vllm.model_executor.models.deepseek_v2 import \
yarn_get_mscale # ruff: noqa: E501 yarn_get_mscale # noqa: E501
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention, from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
DeepseekV2DecoderLayer, DeepseekV2DecoderLayer,
DeepseekV2MLAAttention) DeepseekV2MLAAttention)
@@ -79,7 +79,6 @@ from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.utils import dispose_tensor from vllm_ascend.utils import dispose_tensor
VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
class CustomDeepseekDBOMLP(CustomDeepseekV2MLP): class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
@@ -189,26 +188,8 @@ class CustomDeepseekDBOMoE(nn.Module):
if hasattr(attn_metadata, 'with_prefill_across_dp'): if hasattr(attn_metadata, 'with_prefill_across_dp'):
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
num_tokens, hidden_size = hidden_states.shape
old_hidden_states = hidden_states.clone() old_hidden_states = hidden_states.clone()
if self.tp_size > 1:
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
hidden_states = chunks[self.tp_rank]
elif not self.torchair_graph_enabled:
num_padding_tokens = (self.tp_size -
num_tokens % self.tp_size) % self.tp_size
# Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
if num_padding_tokens > 0:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, num_padding_tokens))
chunk_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
hidden_states = chunk_hidden_states[self.tp_rank]
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
@@ -220,33 +201,13 @@ class CustomDeepseekDBOMoE(nn.Module):
enable_force_load_balance=enable_force_load_balance, enable_force_load_balance=enable_force_load_balance,
) * self.routed_scaling_factor ) * self.routed_scaling_factor
if self.tp_size > 1:
if self.torchair_graph_enabled:
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
final_hidden_states = torch.zeros(
[num_tokens, hidden_size],
dtype=self.params_dtype,
device="npu")
dist.all_gather_into_tensor(final_hidden_states,
hidden_states, self.tp_group)
hidden_states = final_hidden_states
else:
hidden_states = tensor_model_parallel_all_reduce(
hidden_states)
else:
dist.all_gather(list(chunk_hidden_states), hidden_states,
self.tp_group)
hidden_states = torch.cat(chunk_hidden_states, dim=0)
if num_padding_tokens > 0:
hidden_states = hidden_states[:-num_padding_tokens]
if self.n_shared_experts is not None: if self.n_shared_experts is not None:
shared_output = self.shared_experts(old_hidden_states) shared_output = self.shared_experts(old_hidden_states)
if shared_output is not None: if shared_output is not None:
hidden_states = hidden_states + shared_output hidden_states = hidden_states + shared_output
return hidden_states.view(num_tokens, hidden_size) return hidden_states
# ----------------------------------------- TBO-related -------------------------------------------- # ----------------------------------------- TBO-related --------------------------------------------
def _forward_ms_op_shared_expert( def _forward_ms_op_shared_expert(

View File

@@ -28,7 +28,6 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import torch import torch
import torch.distributed as dist
import torch_npu import torch_npu
import vllm.envs as envs import vllm.envs as envs
from torch import nn from torch import nn
@@ -37,7 +36,7 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, ModelConfig, VllmConfig from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group, from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
get_tp_group, tensor_model_parallel_all_reduce) get_tp_group)
from vllm.distributed.parallel_state import get_dp_group from vllm.distributed.parallel_state import get_dp_group
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
@@ -54,9 +53,9 @@ from vllm.model_executor.layers.sampler import get_sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding) ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.models.deepseek_v2 import \ from vllm.model_executor.models.deepseek_v2 import \
DeepseekV2ForCausalLM # ruff: noqa: E501 DeepseekV2ForCausalLM # noqa: E501
from vllm.model_executor.models.deepseek_v2 import \ from vllm.model_executor.models.deepseek_v2 import \
yarn_get_mscale # ruff: noqa: E501 yarn_get_mscale # noqa: E501
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention, from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
DeepseekV2DecoderLayer, DeepseekV2DecoderLayer,
DeepseekV2MLAAttention) DeepseekV2MLAAttention)
@@ -65,7 +64,6 @@ from vllm.model_executor.models.utils import (
maybe_prefix) maybe_prefix)
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE from vllm_ascend.ops.fused_moe import AscendFusedMoE
@@ -74,8 +72,6 @@ from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch, from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
npu_wait_tensor) npu_wait_tensor)
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
class CustomDeepseekV2SiluAndMul(SiluAndMul): class CustomDeepseekV2SiluAndMul(SiluAndMul):
@@ -240,9 +236,8 @@ class CustomDeepseekV2MoE(nn.Module):
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
# NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
self.enable_multistream_moe = \ self.enable_multistream_moe = \
ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2 ascend_config.torchair_graph_config.enable_multistream_moe
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
config.n_routed_experts, config.n_routed_experts,
@@ -312,22 +307,6 @@ class CustomDeepseekV2MoE(nn.Module):
enable_force_load_balance = False enable_force_load_balance = False
if hasattr(attn_metadata, 'with_prefill_across_dp'): if hasattr(attn_metadata, 'with_prefill_across_dp'):
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
num_tokens, hidden_size = hidden_states.shape
old_hidden_states = hidden_states
use_separated_shared_experts = (self.shared_experts is not None
and not self.enable_multistream_moe)
if self.tp_size > 1:
if (VLLM_ENABLE_MC2
and not is_prefill) or not (self.torchair_graph_enabled or
self.ep_group.world_size == 1):
if num_tokens < self.tp_size:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, self.tp_size - num_tokens))
chunk_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
hidden_states = chunk_hidden_states[self.tp_rank]
# router_logits: (num_tokens, n_experts) # router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states) router_logits, _ = self.gate(hidden_states)
@@ -338,34 +317,14 @@ class CustomDeepseekV2MoE(nn.Module):
is_prefill=is_prefill, is_prefill=is_prefill,
top_k=CustomDeepseekV2MoE.top_k, top_k=CustomDeepseekV2MoE.top_k,
enable_force_load_balance=enable_force_load_balance, enable_force_load_balance=enable_force_load_balance,
shared_experts=(self.shared_experts shared_experts=self.shared_experts,
if not use_separated_shared_experts else None),
) )
if not isinstance(experts_hidden_states, tuple):
hidden_states = experts_hidden_states * self.routed_scaling_factor
else:
hidden_states = ( hidden_states = (
experts_hidden_states[0] * self.routed_scaling_factor + experts_hidden_states[0] * self.routed_scaling_factor +
experts_hidden_states[1]) experts_hidden_states[1])
if self.tp_size > 1: return hidden_states
if (VLLM_ENABLE_MC2
and not is_prefill) or not (self.torchair_graph_enabled or
self.ep_group.world_size == 1):
dist.all_gather(list(chunk_hidden_states), hidden_states,
self.tp_group)
hidden_states = torch.cat(chunk_hidden_states, dim=0)
if num_tokens < self.tp_size:
hidden_states = hidden_states[:num_tokens]
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
if use_separated_shared_experts:
hidden_states = hidden_states + self.shared_experts(
old_hidden_states)
return hidden_states.view(num_tokens, hidden_size)
class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention): class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):

View File

@@ -21,11 +21,13 @@ from typing import Any, Callable, List, Optional, Tuple, Union
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch_npu import torch_npu
from torch import nn
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed import (GroupCoordinator, from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce) tensor_model_parallel_all_reduce)
from vllm.distributed.parallel_state import get_dp_group from vllm.distributed.parallel_state import get_dp_group, get_tp_group
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod, FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod,
determine_expert_map) determine_expert_map)
@@ -36,10 +38,10 @@ import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
get_fused_moe_state, npu_stream_switch,
npu_wait_tensor)
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
@@ -845,8 +847,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
super().__init__(moe=moe) super().__init__(moe=moe)
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
ep_group = get_ep_group() self.ep_group = get_ep_group()
self.ep_size = ep_group.world_size self.ep_size = self.ep_group.world_size
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.local_batch_size = self.global_batch_size // self.ep_size self.local_batch_size = self.global_batch_size // self.ep_size
self.max_model_len = vllm_config.model_config.max_model_len self.max_model_len = vllm_config.model_config.max_model_len
@@ -855,7 +857,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
try: try:
device_group = ep_group.device_group device_group = self.ep_group.device_group
# TODO: Try local_rank = ep_group.rank_in_group # TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group) local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu")) backend = device_group._get_backend(torch.device("npu"))
@@ -931,7 +933,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
if enable_force_load_balance: if enable_force_load_balance:
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts) topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
if VLLM_ENABLE_MC2 and not is_prefill: fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
is_prefill)
if fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2( return fused_experts_with_mc2(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
@@ -942,7 +946,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
expert_map=expert_map, expert_map=expert_map,
moe_all_to_all_group_name=self.moe_all_to_all_group_name, moe_all_to_all_group_name=self.moe_all_to_all_group_name,
shared_experts=shared_experts) shared_experts=shared_experts)
elif self.torchair_graph_enabled or get_ep_group().world_size == 1: elif fused_moe_state == FusedMoEState.AllGather:
return fused_experts(hidden_states=x, return fused_experts(hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w2=layer.w2_weight, w2=layer.w2_weight,
@@ -1022,9 +1026,6 @@ class AscendFusedMoE(FusedMoE):
get_dp_group().world_size), get_dp_group().world_size),
vllm_parallel_config=vllm_config.parallel_config)) vllm_parallel_config=vllm_config.parallel_config))
self.moe_parallel_config.ep_size = get_ep_group().world_size
self.moe_parallel_config.tp_size = get_etp_group().world_size
self.top_k = top_k self.top_k = top_k
self.num_experts = num_experts self.num_experts = num_experts
self.global_num_experts = num_experts self.global_num_experts = num_experts
@@ -1066,10 +1067,9 @@ class AscendFusedMoE(FusedMoE):
self.ep_size, self.ep_size,
get_ep_group().rank_in_group, self.global_num_experts) get_ep_group().rank_in_group, self.global_num_experts)
self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
self.enable_multistream_moe = \
ascend_config.torchair_graph_config.enable_multistream_moe
if self.scoring_func != "softmax" and not self.use_grouped_topk: if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for " raise ValueError("Only softmax scoring function is supported for "
@@ -1109,6 +1109,8 @@ class AscendFusedMoE(FusedMoE):
moe_quant_params["intermediate_size_full"] = intermediate_size moe_quant_params["intermediate_size_full"] = intermediate_size
self.ep_group = get_ep_group() self.ep_group = get_ep_group()
# NOTE: self.tp_group is not expert_tp_group
self.tp_group = get_tp_group().device_group
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
def forward(self, def forward(self,
@@ -1125,25 +1127,45 @@ class AscendFusedMoE(FusedMoE):
else: else:
real_top_k = self.top_k real_top_k = self.top_k
# MC2 ag/rs broadcast/all_reduce num_tokens, hidden_size = hidden_states.shape
# prefill_req x x √
# decode_req √ x √ fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
# graph_mode x is_prefill)
if self.dp_size > 1: if shared_experts:
if VLLM_ENABLE_MC2 and not is_prefill: if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
... shared_hidden_states = shared_experts(hidden_states)
elif self.torchair_graph_enabled:
if USING_LCCL_COM: # type: ignore tp_size = get_tensor_model_parallel_world_size()
hidden_states = get_dp_group().all_gather( if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
hidden_states, 0, False) if num_tokens < tp_size:
router_logits = get_dp_group().all_gather( hidden_states = nn.functional.pad(
router_logits, 0, False) hidden_states, (0, 0, 0, tp_size - num_tokens))
elif self.torchair_graph_enabled and not is_prefill: router_logits = nn.functional.pad(
router_logits, (0, 0, 0, tp_size - num_tokens))
chunk_hidden_states = torch.tensor_split(hidden_states,
tp_size,
dim=0)
chunk_router_logits = torch.tensor_split(router_logits,
tp_size,
dim=0)
tp_rank = get_tensor_model_parallel_rank()
hidden_states = chunk_hidden_states[tp_rank]
router_logits = chunk_router_logits[tp_rank]
if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
# NOTE: When in torchair graph, it has been padded in model_runner_v1
if not self.torchair_graph_enabled or is_prefill:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata is not None:
max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
if num_tokens < max_num_tokens_across_dp:
hidden_states = nn.functional.pad(
hidden_states,
(0, 0, 0, max_num_tokens_across_dp - num_tokens))
router_logits = nn.functional.pad(
router_logits,
(0, 0, 0, max_num_tokens_across_dp - num_tokens))
hidden_states = get_dp_group().all_gather(hidden_states, 0) hidden_states = get_dp_group().all_gather(hidden_states, 0)
router_logits = get_dp_group().all_gather(router_logits, 0) router_logits = get_dp_group().all_gather(router_logits, 0)
else:
hidden_states, router_logits = get_ep_group().dispatch(
hidden_states, router_logits)
# Matrix multiply. # Matrix multiply.
e_hidden_states = self.quant_method.apply( e_hidden_states = self.quant_method.apply(
@@ -1167,36 +1189,36 @@ class AscendFusedMoE(FusedMoE):
shared_experts=shared_experts, shared_experts=shared_experts,
) )
if shared_experts is not None: if shared_experts:
# Provide dummy implementation of "non-separated" shared experts. if isinstance(e_hidden_states, tuple):
if not isinstance(e_hidden_states, tuple): e_hidden_states, shared_hidden_states = e_hidden_states
return e_hidden_states, shared_experts(hidden_states)
else:
return e_hidden_states
if self.dp_size > 1: if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
if VLLM_ENABLE_MC2 and not is_prefill: dist.all_gather(list(chunk_hidden_states), e_hidden_states,
... self.tp_group)
elif self.torchair_graph_enabled: final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
if USING_LCCL_COM: # type: ignore if num_tokens < tp_size:
e_hidden_states = dist._functional_collectives.reduce_scatter_tensor( final_hidden_states = final_hidden_states[:num_tokens]
e_hidden_states, dispose_tensor(e_hidden_states)
"sum", elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
scatter_dim=0, final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
group=get_dp_group().device_group)
elif self.torchair_graph_enabled and not is_prefill:
e_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
e_hidden_states, e_hidden_states,
"sum", "sum",
scatter_dim=0, scatter_dim=0,
group=get_dp_group().device_group) group=get_dp_group().device_group)
final_hidden_states = final_hidden_states[:num_tokens]
dispose_tensor(e_hidden_states)
else: else:
e_hidden_states = get_ep_group().combine(e_hidden_states) final_hidden_states = e_hidden_states
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1): if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
e_hidden_states = tensor_model_parallel_all_reduce(e_hidden_states) final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
return e_hidden_states if shared_experts:
return final_hidden_states, shared_hidden_states
else:
return final_hidden_states
# ----------------------------------------- TBO-related -------------------------------------------- # ----------------------------------------- TBO-related --------------------------------------------

View File

@@ -22,15 +22,13 @@ import torch.distributed as dist
import torch_npu import torch_npu
from vllm.distributed import GroupCoordinator from vllm.distributed import GroupCoordinator
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import select_experts from vllm_ascend.ops.fused_moe import select_experts
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch, from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
get_fused_moe_state, npu_stream_switch,
npu_wait_tensor) npu_wait_tensor)
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
def apply_mlp(hidden_states: torch.Tensor, def apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor, w1: torch.Tensor,
@@ -660,7 +658,9 @@ class AscendW8A8DynamicFusedMoEMethod:
topk_weights = topk_weights.to(x.dtype) topk_weights = topk_weights.to(x.dtype)
if VLLM_ENABLE_MC2 and not is_prefill: fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
is_prefill)
if fused_moe_state == FusedMoEState.MC2:
return fused_experts_with_mc2( return fused_experts_with_mc2(
hidden_states=x, hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
@@ -675,7 +675,7 @@ class AscendW8A8DynamicFusedMoEMethod:
log2phy=log2phy, log2phy=log2phy,
global_redundant_expert_num=global_redundant_expert_num, global_redundant_expert_num=global_redundant_expert_num,
shared_experts=shared_experts) shared_experts=shared_experts)
elif self.torchair_graph_enabled or self.ep_group.world_size == 1: elif fused_moe_state == FusedMoEState.AllGather:
return fused_experts(hidden_states=x, return fused_experts(hidden_states=x,
w1=layer.w13_weight, w1=layer.w13_weight,
w1_scale=layer.w13_weight_scale, w1_scale=layer.w13_weight_scale,

View File

@@ -20,6 +20,7 @@
import atexit import atexit
import math import math
from contextlib import contextmanager, nullcontext from contextlib import contextmanager, nullcontext
from enum import Enum
from threading import Lock from threading import Lock
from typing import TYPE_CHECKING, List, Tuple from typing import TYPE_CHECKING, List, Tuple
@@ -275,3 +276,21 @@ def npu_wait_tensor(self: torch.Tensor,
*, *,
enabled: bool = True): enabled: bool = True):
return _npu_wait_tensor(self, dependency) if enabled else self return _npu_wait_tensor(self, dependency) if enabled else self
# TODO(zzzzwwjj): move this into forward_context
class FusedMoEState(Enum):
    """Closed set of strategies the fused MoE layer can run with.

    The member is selected per forward pass by ``get_fused_moe_state`` and
    compared against by the fused MoE layer and its quantization methods to
    pick the matching code path.
    """

    # Selected when expert parallelism is not in use (ep_size == 1).
    AllGather = 0
    # Selected for small expert-parallel groups or when prefill is present.
    All2All = 1
    # Selected for decode-only batches on large expert-parallel groups.
    MC2 = 2
# TODO(zzzzwwjj): add soc_version to choose branch
def get_fused_moe_state(ep_size: int, with_prefill: bool) -> "FusedMoEState":
    """Choose the fused MoE strategy for the current forward pass.

    Args:
        ep_size: World size of the expert-parallel group.
        with_prefill: Whether the current batch contains prefill work.

    Returns:
        ``FusedMoEState.AllGather`` when ``ep_size`` is 1;
        ``FusedMoEState.All2All`` when ``ep_size`` is below 16 or
        ``with_prefill`` is true; ``FusedMoEState.MC2`` otherwise.
    """
    if ep_size == 1:
        return FusedMoEState.AllGather
    # NOTE: MC2 needs ep_size >= 16, and all2all cannot be used in torchair
    # graph mode.
    elif ep_size < 16 or with_prefill:
        return FusedMoEState.All2All
    else:
        return FusedMoEState.MC2

View File

@@ -348,15 +348,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.init_torchair_graph_batch_sizes() self.init_torchair_graph_batch_sizes()
if len(self.torchair_graph_batch_sizes) == 0: if len(self.torchair_graph_batch_sizes) == 0:
#If MC2 is enabled, torchair_graph_batch_size should pad to tp_size # TODO(zzzzwwjj): check torchair_graph_batch_sizes init code
if envs_ascend.VLLM_ENABLE_MC2:
self.torchair_graph_batch_sizes = [ self.torchair_graph_batch_sizes = [
self.scheduler_config.max_num_seqs self.scheduler_config.max_num_seqs
] ]
else:
self.torchair_graph_batch_sizes = [
1, self.scheduler_config.max_num_seqs
]
torch._dynamo.cache_size.config.cache_size_limit += len( torch._dynamo.cache_size.config.cache_size_limit += len(
self.torchair_graph_batch_sizes) self.torchair_graph_batch_sizes)
@@ -569,8 +564,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
self.input_batch.refresh_sampling_metadata() self.input_batch.refresh_sampling_metadata()
def _get_forward_metadata_across_dp( def _get_forward_metadata_across_dp(
self, batch_size: int, with_prefill: bool) -> tuple[int, bool]: self, total_num_scheduled_tokens: int,
forward_metadata = torch.tensor([batch_size, with_prefill], with_prefill: bool) -> tuple[int, bool]:
forward_metadata = torch.tensor(
[total_num_scheduled_tokens, with_prefill],
device="cpu", device="cpu",
dtype=torch.int32) dtype=torch.int32)
dist.all_reduce(forward_metadata, dist.all_reduce(forward_metadata,
@@ -901,11 +898,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
if self.dp_size > 1: if self.dp_size > 1:
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp( max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
total_num_scheduled_tokens, with_prefill) total_num_scheduled_tokens, with_prefill)
extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
# Add graph_pad_size here # Add graph_pad_size here
if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled if self.torchair_graph_enabled and not with_prefill:
and not with_prefill):
if self.dp_size > 1: if self.dp_size > 1:
padded_batch_size = self.select_torchair_padded_batch_size( padded_batch_size = self.select_torchair_padded_batch_size(
max_num_tokens) max_num_tokens)
@@ -984,8 +981,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
else: else:
positions = self.positions[:num_input_tokens] positions = self.positions[:num_input_tokens]
if (envs_ascend.VLLM_ENABLE_MC2 if self.torchair_graph_enabled and not with_prefill:
or self.torchair_graph_enabled) and not with_prefill:
input_ids = self.input_ids[:padded_batch_size] input_ids = self.input_ids[:padded_batch_size]
positions = self.positions[:padded_batch_size] positions = self.positions[:padded_batch_size]
@@ -1885,20 +1881,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
return spec_token_ids return spec_token_ids
def init_torchair_graph_batch_sizes(self): def init_torchair_graph_batch_sizes(self):
start_graph_batch_size = 4
tp_size = get_tensor_model_parallel_world_size() tp_size = get_tensor_model_parallel_world_size()
batch_size_step = 8
largest_batch_size = 1
if envs_ascend.VLLM_ENABLE_MC2: # NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
batch_size_step = max(batch_size_step, tp_size) start_graph_batch_size = max(start_graph_batch_size, tp_size)
largest_batch_size = batch_size_step
while (largest_batch_size < 8):
self.torchair_graph_batch_sizes.append(largest_batch_size)
largest_batch_size *= 2
while (largest_batch_size <= self.scheduler_config.max_num_seqs): while (start_graph_batch_size <= self.scheduler_config.max_num_seqs):
self.torchair_graph_batch_sizes.append(largest_batch_size) self.torchair_graph_batch_sizes.append(start_graph_batch_size)
largest_batch_size += batch_size_step start_graph_batch_size *= 2
def select_torchair_padded_batch_size(self, batch_size: int): def select_torchair_padded_batch_size(self, batch_size: int):
selected_batch_size = self.max_num_reqs selected_batch_size = self.max_num_reqs

View File

@@ -38,7 +38,6 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ModelRunnerOutput from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.worker.worker_base import WorkerBase from vllm.v1.worker.worker_base import WorkerBase
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import init_ascend_config from vllm_ascend.ascend_config import init_ascend_config
from vllm_ascend.device_allocator.camem import CaMemAllocator from vllm_ascend.device_allocator.camem import CaMemAllocator
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
@@ -247,15 +246,15 @@ class NPUWorker(WorkerBase):
def execute_dummy_batch(self) -> None: def execute_dummy_batch(self) -> None:
runner = self.model_runner runner = self.model_runner
num_tokens = 1 max_num_tokens = 1
with_prefill = False
if runner.dp_size > 1: if runner.dp_size > 1:
max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp( max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
1, False) max_num_tokens, with_prefill)
if envs_ascend.VLLM_ENABLE_MC2 or runner.torchair_graph_enabled: if runner.torchair_graph_enabled and not with_prefill:
if not with_prefill: max_num_tokens = runner.select_torchair_padded_batch_size(
num_tokens = max_num_tokens max_num_tokens)
num_tokens = runner.select_torchair_padded_batch_size(num_tokens) runner._dummy_run(max_num_tokens,
runner._dummy_run(num_tokens,
is_compile=False, is_compile=False,
with_prefill=with_prefill) with_prefill=with_prefill)