[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? This PR is used for resolved [issue 1147](https://github.com/vllm-project/vllm-ascend/issues/1147) 1. Move fused_moe code into one file `fused_moe.py`. 2. Integrate branch conditions into function `get_fused_moe_state`. <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? 1. This PR has removed the env `VLLM_ENABLE_MC2`, because I think this env is useless, we can make judgments based on the current scenario without this env, it will only increase complexity. 2. This PR has removed the env `USING_LCCL_COM`, because this env has already expired. 3. `additional_config.expert_tensor_parallel_size` has already expired, and now we also use parameter `enable_expert_parallel`, consistent with the vLLM. <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -136,6 +136,7 @@ class AscendMLAMetadata:
|
||||
# For logging.
|
||||
num_input_tokens: int = 0 # Number of tokens including padding.
|
||||
|
||||
max_num_tokens_across_dp: int = 0
|
||||
with_prefill_across_dp: bool = False
|
||||
|
||||
query_lens: Optional[list[int]] = None
|
||||
@@ -364,6 +365,7 @@ class AscendMLAMetadataBuilder:
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
common_prefix_len: Optional[int] = None,
|
||||
graph_pad_size: int = -1,
|
||||
max_num_tokens_across_dp: int = 0,
|
||||
with_prefill_across_dp: bool = False,
|
||||
) -> AscendMLAMetadata:
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
@@ -509,6 +511,7 @@ class AscendMLAMetadataBuilder:
|
||||
query_start_loc=query_start_loc,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
max_num_tokens_across_dp=max_num_tokens_across_dp,
|
||||
with_prefill_across_dp=with_prefill_across_dp,
|
||||
)
|
||||
|
||||
|
||||
@@ -50,18 +50,10 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
# value is None, which means the system default C compiler will be used.
|
||||
"C_COMPILER":
|
||||
lambda: os.getenv("C_COMPILER", None),
|
||||
# Whether to enable MC2 for DeepSeek. If not set, the default value is False.
|
||||
# MC2 is a fusion operator provided by Ascend to speed up computing and communication.
|
||||
# Find more detail here: https://www.hiascend.com/document/detail/zh/canncommercial/81RC1/developmentguide/opdevg/ascendcbestP/atlas_ascendc_best_practices_10_0043.html
|
||||
"VLLM_ENABLE_MC2":
|
||||
lambda: bool(int(os.getenv("VLLM_ENABLE_MC2", '0'))),
|
||||
# Whether to enable the topk optimization. It's disabled by default for experimental support
|
||||
# We'll make it enabled by default in the future.
|
||||
"VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE":
|
||||
lambda: bool(int(os.getenv("VLLM_ASCEND_ENABLE_TOPK_OPTIMIZE", '0'))),
|
||||
# Whether to use LCCL communication. If not set, the default value is False.
|
||||
"USING_LCCL_COM":
|
||||
lambda: bool(int(os.getenv("USING_LCCL_COM", '0'))),
|
||||
# The version of the Ascend chip. If not set, the default value is
|
||||
# ASCEND910B1. It's used for package building. Please make sure that the
|
||||
# version is correct.
|
||||
|
||||
@@ -51,9 +51,9 @@ from vllm.model_executor.layers.sampler import get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.deepseek_v2 import \
|
||||
DeepseekV2ForCausalLM # ruff: noqa: E501
|
||||
DeepseekV2ForCausalLM # noqa: E501
|
||||
from vllm.model_executor.models.deepseek_v2 import \
|
||||
yarn_get_mscale # ruff: noqa: E501
|
||||
yarn_get_mscale # noqa: E501
|
||||
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
|
||||
DeepseekV2DecoderLayer,
|
||||
DeepseekV2MLAAttention)
|
||||
@@ -79,7 +79,6 @@ from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.utils import dispose_tensor
|
||||
|
||||
VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
|
||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
|
||||
|
||||
class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
|
||||
@@ -189,26 +188,8 @@ class CustomDeepseekDBOMoE(nn.Module):
|
||||
if hasattr(attn_metadata, 'with_prefill_across_dp'):
|
||||
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
||||
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
|
||||
old_hidden_states = hidden_states.clone()
|
||||
|
||||
if self.tp_size > 1:
|
||||
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
|
||||
chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
|
||||
hidden_states = chunks[self.tp_rank]
|
||||
elif not self.torchair_graph_enabled:
|
||||
num_padding_tokens = (self.tp_size -
|
||||
num_tokens % self.tp_size) % self.tp_size
|
||||
# Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
|
||||
if num_padding_tokens > 0:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, num_padding_tokens))
|
||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
hidden_states = chunk_hidden_states[self.tp_rank]
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
@@ -220,33 +201,13 @@ class CustomDeepseekDBOMoE(nn.Module):
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
) * self.routed_scaling_factor
|
||||
|
||||
if self.tp_size > 1:
|
||||
if self.torchair_graph_enabled:
|
||||
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
|
||||
final_hidden_states = torch.zeros(
|
||||
[num_tokens, hidden_size],
|
||||
dtype=self.params_dtype,
|
||||
device="npu")
|
||||
dist.all_gather_into_tensor(final_hidden_states,
|
||||
hidden_states, self.tp_group)
|
||||
hidden_states = final_hidden_states
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(
|
||||
hidden_states)
|
||||
else:
|
||||
dist.all_gather(list(chunk_hidden_states), hidden_states,
|
||||
self.tp_group)
|
||||
hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
if num_padding_tokens > 0:
|
||||
hidden_states = hidden_states[:-num_padding_tokens]
|
||||
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(old_hidden_states)
|
||||
|
||||
if shared_output is not None:
|
||||
hidden_states = hidden_states + shared_output
|
||||
|
||||
return hidden_states.view(num_tokens, hidden_size)
|
||||
return hidden_states
|
||||
|
||||
# ----------------------------------------- TBO-related --------------------------------------------
|
||||
def _forward_ms_op_shared_expert(
|
||||
|
||||
@@ -28,7 +28,6 @@
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
import vllm.envs as envs
|
||||
from torch import nn
|
||||
@@ -37,7 +36,7 @@ from vllm.attention import Attention, AttentionMetadata
|
||||
from vllm.config import CacheConfig, ModelConfig, VllmConfig
|
||||
from vllm.distributed import (get_pp_group,
|
||||
get_tensor_model_parallel_world_size,
|
||||
get_tp_group, tensor_model_parallel_all_reduce)
|
||||
get_tp_group)
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
@@ -54,9 +53,9 @@ from vllm.model_executor.layers.sampler import get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.deepseek_v2 import \
|
||||
DeepseekV2ForCausalLM # ruff: noqa: E501
|
||||
DeepseekV2ForCausalLM # noqa: E501
|
||||
from vllm.model_executor.models.deepseek_v2 import \
|
||||
yarn_get_mscale # ruff: noqa: E501
|
||||
yarn_get_mscale # noqa: E501
|
||||
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
|
||||
DeepseekV2DecoderLayer,
|
||||
DeepseekV2MLAAttention)
|
||||
@@ -65,7 +64,6 @@ from vllm.model_executor.models.utils import (
|
||||
maybe_prefix)
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
@@ -74,8 +72,6 @@ from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
||||
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
|
||||
npu_wait_tensor)
|
||||
|
||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
|
||||
|
||||
class CustomDeepseekV2SiluAndMul(SiluAndMul):
|
||||
|
||||
@@ -240,9 +236,8 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
# NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
|
||||
self.enable_multistream_moe = \
|
||||
ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
|
||||
ascend_config.torchair_graph_config.enable_multistream_moe
|
||||
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
@@ -312,22 +307,6 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
enable_force_load_balance = False
|
||||
if hasattr(attn_metadata, 'with_prefill_across_dp'):
|
||||
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
old_hidden_states = hidden_states
|
||||
use_separated_shared_experts = (self.shared_experts is not None
|
||||
and not self.enable_multistream_moe)
|
||||
|
||||
if self.tp_size > 1:
|
||||
if (VLLM_ENABLE_MC2
|
||||
and not is_prefill) or not (self.torchair_graph_enabled or
|
||||
self.ep_group.world_size == 1):
|
||||
if num_tokens < self.tp_size:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, self.tp_size - num_tokens))
|
||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
hidden_states = chunk_hidden_states[self.tp_rank]
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
@@ -338,34 +317,14 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
is_prefill=is_prefill,
|
||||
top_k=CustomDeepseekV2MoE.top_k,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
shared_experts=(self.shared_experts
|
||||
if not use_separated_shared_experts else None),
|
||||
shared_experts=self.shared_experts,
|
||||
)
|
||||
|
||||
if not isinstance(experts_hidden_states, tuple):
|
||||
hidden_states = experts_hidden_states * self.routed_scaling_factor
|
||||
else:
|
||||
hidden_states = (
|
||||
experts_hidden_states[0] * self.routed_scaling_factor +
|
||||
experts_hidden_states[1])
|
||||
hidden_states = (
|
||||
experts_hidden_states[0] * self.routed_scaling_factor +
|
||||
experts_hidden_states[1])
|
||||
|
||||
if self.tp_size > 1:
|
||||
if (VLLM_ENABLE_MC2
|
||||
and not is_prefill) or not (self.torchair_graph_enabled or
|
||||
self.ep_group.world_size == 1):
|
||||
dist.all_gather(list(chunk_hidden_states), hidden_states,
|
||||
self.tp_group)
|
||||
hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
if num_tokens < self.tp_size:
|
||||
hidden_states = hidden_states[:num_tokens]
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
|
||||
if use_separated_shared_experts:
|
||||
hidden_states = hidden_states + self.shared_experts(
|
||||
old_hidden_states)
|
||||
|
||||
return hidden_states.view(num_tokens, hidden_size)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
|
||||
@@ -21,11 +21,13 @@ from typing import Any, Callable, List, Optional, Tuple, Union
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from torch import nn
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (GroupCoordinator,
|
||||
from vllm.distributed import (GroupCoordinator, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.distributed.parallel_state import get_dp_group, get_tp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, FusedMoEParallelConfig, MoEConfig, UnquantizedFusedMoEMethod,
|
||||
determine_expert_map)
|
||||
@@ -36,10 +38,10 @@ import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
|
||||
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
|
||||
from vllm_ascend.utils import npu_stream_switch, npu_wait_tensor
|
||||
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
|
||||
get_fused_moe_state, npu_stream_switch,
|
||||
npu_wait_tensor)
|
||||
|
||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
USING_LCCL_COM: bool = envs_ascend.USING_LCCL_COM
|
||||
MOE_ALL2ALL_BUFFER: bool = envs_ascend.MOE_ALL2ALL_BUFFER
|
||||
|
||||
|
||||
@@ -845,8 +847,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
super().__init__(moe=moe)
|
||||
vllm_config = get_current_vllm_config()
|
||||
|
||||
ep_group = get_ep_group()
|
||||
self.ep_size = ep_group.world_size
|
||||
self.ep_group = get_ep_group()
|
||||
self.ep_size = self.ep_group.world_size
|
||||
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
self.local_batch_size = self.global_batch_size // self.ep_size
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
@@ -855,7 +857,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
|
||||
try:
|
||||
device_group = ep_group.device_group
|
||||
device_group = self.ep_group.device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
@@ -931,7 +933,9 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
|
||||
is_prefill)
|
||||
if fused_moe_state == FusedMoEState.MC2:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -942,7 +946,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
shared_experts=shared_experts)
|
||||
elif self.torchair_graph_enabled or get_ep_group().world_size == 1:
|
||||
elif fused_moe_state == FusedMoEState.AllGather:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w2=layer.w2_weight,
|
||||
@@ -1022,9 +1026,6 @@ class AscendFusedMoE(FusedMoE):
|
||||
get_dp_group().world_size),
|
||||
vllm_parallel_config=vllm_config.parallel_config))
|
||||
|
||||
self.moe_parallel_config.ep_size = get_ep_group().world_size
|
||||
self.moe_parallel_config.tp_size = get_etp_group().world_size
|
||||
|
||||
self.top_k = top_k
|
||||
self.num_experts = num_experts
|
||||
self.global_num_experts = num_experts
|
||||
@@ -1066,10 +1067,9 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.ep_size,
|
||||
get_ep_group().rank_in_group, self.global_num_experts)
|
||||
|
||||
self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
|
||||
self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
|
||||
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
self.enable_multistream_moe = \
|
||||
ascend_config.torchair_graph_config.enable_multistream_moe
|
||||
|
||||
if self.scoring_func != "softmax" and not self.use_grouped_topk:
|
||||
raise ValueError("Only softmax scoring function is supported for "
|
||||
@@ -1109,6 +1109,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
|
||||
self.ep_group = get_ep_group()
|
||||
# NOTE: self.tp_group is not expert_tp_group
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
def forward(self,
|
||||
@@ -1125,25 +1127,45 @@ class AscendFusedMoE(FusedMoE):
|
||||
else:
|
||||
real_top_k = self.top_k
|
||||
|
||||
# MC2 ag/rs broadcast/all_reduce
|
||||
# prefill_req x x √
|
||||
# decode_req √ x √
|
||||
# graph_mode √ √ x
|
||||
if self.dp_size > 1:
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
...
|
||||
elif self.torchair_graph_enabled:
|
||||
if USING_LCCL_COM: # type: ignore
|
||||
hidden_states = get_dp_group().all_gather(
|
||||
hidden_states, 0, False)
|
||||
router_logits = get_dp_group().all_gather(
|
||||
router_logits, 0, False)
|
||||
elif self.torchair_graph_enabled and not is_prefill:
|
||||
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||
else:
|
||||
hidden_states, router_logits = get_ep_group().dispatch(
|
||||
hidden_states, router_logits)
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
|
||||
fused_moe_state = get_fused_moe_state(self.moe_parallel_config.ep_size,
|
||||
is_prefill)
|
||||
if shared_experts:
|
||||
if not self.enable_multistream_moe or fused_moe_state != FusedMoEState.MC2:
|
||||
shared_hidden_states = shared_experts(hidden_states)
|
||||
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
|
||||
if num_tokens < tp_size:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, tp_size - num_tokens))
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits, (0, 0, 0, tp_size - num_tokens))
|
||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||
tp_size,
|
||||
dim=0)
|
||||
chunk_router_logits = torch.tensor_split(router_logits,
|
||||
tp_size,
|
||||
dim=0)
|
||||
tp_rank = get_tensor_model_parallel_rank()
|
||||
hidden_states = chunk_hidden_states[tp_rank]
|
||||
router_logits = chunk_router_logits[tp_rank]
|
||||
if self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
|
||||
# NOTE: When in torchair graph, it has been padded in model_runner_v1
|
||||
if not self.torchair_graph_enabled or is_prefill:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
if attn_metadata is not None:
|
||||
max_num_tokens_across_dp = attn_metadata.max_num_tokens_across_dp
|
||||
if num_tokens < max_num_tokens_across_dp:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states,
|
||||
(0, 0, 0, max_num_tokens_across_dp - num_tokens))
|
||||
router_logits = nn.functional.pad(
|
||||
router_logits,
|
||||
(0, 0, 0, max_num_tokens_across_dp - num_tokens))
|
||||
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||
|
||||
# Matrix multiply.
|
||||
e_hidden_states = self.quant_method.apply(
|
||||
@@ -1167,36 +1189,36 @@ class AscendFusedMoE(FusedMoE):
|
||||
shared_experts=shared_experts,
|
||||
)
|
||||
|
||||
if shared_experts is not None:
|
||||
# Provide dummy implementation of "non-separated" shared experts.
|
||||
if not isinstance(e_hidden_states, tuple):
|
||||
return e_hidden_states, shared_experts(hidden_states)
|
||||
else:
|
||||
return e_hidden_states
|
||||
if shared_experts:
|
||||
if isinstance(e_hidden_states, tuple):
|
||||
e_hidden_states, shared_hidden_states = e_hidden_states
|
||||
|
||||
if self.dp_size > 1:
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
...
|
||||
elif self.torchair_graph_enabled:
|
||||
if USING_LCCL_COM: # type: ignore
|
||||
e_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||
e_hidden_states,
|
||||
"sum",
|
||||
scatter_dim=0,
|
||||
group=get_dp_group().device_group)
|
||||
elif self.torchair_graph_enabled and not is_prefill:
|
||||
e_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||
e_hidden_states,
|
||||
"sum",
|
||||
scatter_dim=0,
|
||||
group=get_dp_group().device_group)
|
||||
else:
|
||||
e_hidden_states = get_ep_group().combine(e_hidden_states)
|
||||
if tp_size > 1 and fused_moe_state != FusedMoEState.AllGather:
|
||||
dist.all_gather(list(chunk_hidden_states), e_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
if num_tokens < tp_size:
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
dispose_tensor(e_hidden_states)
|
||||
elif self.dp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
|
||||
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||
e_hidden_states,
|
||||
"sum",
|
||||
scatter_dim=0,
|
||||
group=get_dp_group().device_group)
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
dispose_tensor(e_hidden_states)
|
||||
else:
|
||||
final_hidden_states = e_hidden_states
|
||||
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
e_hidden_states = tensor_model_parallel_all_reduce(e_hidden_states)
|
||||
if tp_size > 1 and fused_moe_state == FusedMoEState.AllGather:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
|
||||
return e_hidden_states
|
||||
if shared_experts:
|
||||
return final_hidden_states, shared_hidden_states
|
||||
else:
|
||||
return final_hidden_states
|
||||
|
||||
# ----------------------------------------- TBO-related --------------------------------------------
|
||||
|
||||
|
||||
@@ -22,15 +22,13 @@ import torch.distributed as dist
|
||||
import torch_npu
|
||||
from vllm.distributed import GroupCoordinator
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.ops.fused_moe import select_experts
|
||||
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
|
||||
from vllm_ascend.utils import (FusedMoEState, dispose_tensor,
|
||||
get_fused_moe_state, npu_stream_switch,
|
||||
npu_wait_tensor)
|
||||
|
||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
|
||||
|
||||
def apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
@@ -660,7 +658,9 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
|
||||
topk_weights = topk_weights.to(x.dtype)
|
||||
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
fused_moe_state = get_fused_moe_state(self.ep_group.world_size,
|
||||
is_prefill)
|
||||
if fused_moe_state == FusedMoEState.MC2:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -675,7 +675,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
shared_experts=shared_experts)
|
||||
elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
|
||||
elif fused_moe_state == FusedMoEState.AllGather:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
|
||||
@@ -20,6 +20,7 @@
|
||||
import atexit
|
||||
import math
|
||||
from contextlib import contextmanager, nullcontext
|
||||
from enum import Enum
|
||||
from threading import Lock
|
||||
from typing import TYPE_CHECKING, List, Tuple
|
||||
|
||||
@@ -275,3 +276,21 @@ def npu_wait_tensor(self: torch.Tensor,
|
||||
*,
|
||||
enabled: bool = True):
|
||||
return _npu_wait_tensor(self, dependency) if enabled else self
|
||||
|
||||
|
||||
# TODO(zzzzwwjj): move this into forward_context
|
||||
class FusedMoEState(Enum):
|
||||
AllGather = 0
|
||||
All2All = 1
|
||||
MC2 = 2
|
||||
|
||||
|
||||
# TODO(zzzzwwjj): add soc_version to choose branch
|
||||
def get_fused_moe_state(ep_size: int, with_prefill: bool):
|
||||
if ep_size == 1:
|
||||
return FusedMoEState.AllGather
|
||||
# NOTE: mc2 need ep_size >= 16 & all2all can't use in torchair graph.
|
||||
elif ep_size < 16 or with_prefill:
|
||||
return FusedMoEState.All2All
|
||||
else:
|
||||
return FusedMoEState.MC2
|
||||
|
||||
@@ -348,15 +348,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.init_torchair_graph_batch_sizes()
|
||||
|
||||
if len(self.torchair_graph_batch_sizes) == 0:
|
||||
#If MC2 is enabled, torchair_graph_batch_size should pad to tp_size
|
||||
if envs_ascend.VLLM_ENABLE_MC2:
|
||||
self.torchair_graph_batch_sizes = [
|
||||
self.scheduler_config.max_num_seqs
|
||||
]
|
||||
else:
|
||||
self.torchair_graph_batch_sizes = [
|
||||
1, self.scheduler_config.max_num_seqs
|
||||
]
|
||||
# TODO(zzzzwwjj): check torchair_graph_batch_sizes init code
|
||||
self.torchair_graph_batch_sizes = [
|
||||
self.scheduler_config.max_num_seqs
|
||||
]
|
||||
|
||||
torch._dynamo.cache_size.config.cache_size_limit += len(
|
||||
self.torchair_graph_batch_sizes)
|
||||
@@ -569,10 +564,12 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
self.input_batch.refresh_sampling_metadata()
|
||||
|
||||
def _get_forward_metadata_across_dp(
|
||||
self, batch_size: int, with_prefill: bool) -> tuple[int, bool]:
|
||||
forward_metadata = torch.tensor([batch_size, with_prefill],
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
self, total_num_scheduled_tokens: int,
|
||||
with_prefill: bool) -> tuple[int, bool]:
|
||||
forward_metadata = torch.tensor(
|
||||
[total_num_scheduled_tokens, with_prefill],
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
dist.all_reduce(forward_metadata,
|
||||
op=ReduceOp.MAX,
|
||||
group=get_dp_group().cpu_group)
|
||||
@@ -901,11 +898,11 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
if self.dp_size > 1:
|
||||
max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
|
||||
total_num_scheduled_tokens, with_prefill)
|
||||
extra_builder_kwargs['max_num_tokens_across_dp'] = max_num_tokens
|
||||
extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
|
||||
|
||||
# Add graph_pad_size here
|
||||
if envs_ascend.VLLM_ENABLE_MC2 or (self.torchair_graph_enabled
|
||||
and not with_prefill):
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
if self.dp_size > 1:
|
||||
padded_batch_size = self.select_torchair_padded_batch_size(
|
||||
max_num_tokens)
|
||||
@@ -984,8 +981,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
else:
|
||||
positions = self.positions[:num_input_tokens]
|
||||
|
||||
if (envs_ascend.VLLM_ENABLE_MC2
|
||||
or self.torchair_graph_enabled) and not with_prefill:
|
||||
if self.torchair_graph_enabled and not with_prefill:
|
||||
input_ids = self.input_ids[:padded_batch_size]
|
||||
positions = self.positions[:padded_batch_size]
|
||||
|
||||
@@ -1885,20 +1881,15 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
return spec_token_ids
|
||||
|
||||
def init_torchair_graph_batch_sizes(self):
|
||||
start_graph_batch_size = 4
|
||||
tp_size = get_tensor_model_parallel_world_size()
|
||||
batch_size_step = 8
|
||||
largest_batch_size = 1
|
||||
|
||||
if envs_ascend.VLLM_ENABLE_MC2:
|
||||
batch_size_step = max(batch_size_step, tp_size)
|
||||
largest_batch_size = batch_size_step
|
||||
while (largest_batch_size < 8):
|
||||
self.torchair_graph_batch_sizes.append(largest_batch_size)
|
||||
largest_batch_size *= 2
|
||||
# NOTE: When use all2all | mc2, We need to slice the `num_tokens` dimension into `tp_size` blocks
|
||||
start_graph_batch_size = max(start_graph_batch_size, tp_size)
|
||||
|
||||
while (largest_batch_size <= self.scheduler_config.max_num_seqs):
|
||||
self.torchair_graph_batch_sizes.append(largest_batch_size)
|
||||
largest_batch_size += batch_size_step
|
||||
while (start_graph_batch_size <= self.scheduler_config.max_num_seqs):
|
||||
self.torchair_graph_batch_sizes.append(start_graph_batch_size)
|
||||
start_graph_batch_size *= 2
|
||||
|
||||
def select_torchair_padded_batch_size(self, batch_size: int):
|
||||
selected_batch_size = self.max_num_reqs
|
||||
|
||||
@@ -38,7 +38,6 @@ from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
|
||||
from vllm.v1.outputs import ModelRunnerOutput
|
||||
from vllm.v1.worker.worker_base import WorkerBase
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import init_ascend_config
|
||||
from vllm_ascend.device_allocator.camem import CaMemAllocator
|
||||
from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
|
||||
@@ -247,15 +246,15 @@ class NPUWorker(WorkerBase):
|
||||
|
||||
def execute_dummy_batch(self) -> None:
|
||||
runner = self.model_runner
|
||||
num_tokens = 1
|
||||
max_num_tokens = 1
|
||||
with_prefill = False
|
||||
if runner.dp_size > 1:
|
||||
max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
|
||||
1, False)
|
||||
if envs_ascend.VLLM_ENABLE_MC2 or runner.torchair_graph_enabled:
|
||||
if not with_prefill:
|
||||
num_tokens = max_num_tokens
|
||||
num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
|
||||
runner._dummy_run(num_tokens,
|
||||
max_num_tokens, with_prefill)
|
||||
if runner.torchair_graph_enabled and not with_prefill:
|
||||
max_num_tokens = runner.select_torchair_padded_batch_size(
|
||||
max_num_tokens)
|
||||
runner._dummy_run(max_num_tokens,
|
||||
is_compile=False,
|
||||
with_prefill=with_prefill)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user