[refactor] Refactoring AscendFusedMoE (#1229)
<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? This PR is used for resolved [issue 1147](https://github.com/vllm-project/vllm-ascend/issues/1147) 1. Move fused_moe code into one file `fused_moe.py`. 2. Integrate branch conditions into function `get_fused_moe_state`. <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? 1. This PR has removed the env `VLLM_ENABLE_MC2`, because I think this env is useless, we can make judgments based on the current scenario without this env, it will only increase complexity. 2. This PR has removed the env `USING_LCCL_COM`, because this env has already expired. 3. `additional_config.expert_tensor_parallel_size` has already expired, and now we also use parameter `enable_expert_parallel`, consistent with the vLLM. <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
@@ -51,9 +51,9 @@ from vllm.model_executor.layers.sampler import get_sampler
|
||||
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
||||
ParallelLMHead, VocabParallelEmbedding)
|
||||
from vllm.model_executor.models.deepseek_v2 import \
|
||||
DeepseekV2ForCausalLM # ruff: noqa: E501
|
||||
DeepseekV2ForCausalLM # noqa: E501
|
||||
from vllm.model_executor.models.deepseek_v2 import \
|
||||
yarn_get_mscale # ruff: noqa: E501
|
||||
yarn_get_mscale # noqa: E501
|
||||
from vllm.model_executor.models.deepseek_v2 import (DeepseekV2Attention,
|
||||
DeepseekV2DecoderLayer,
|
||||
DeepseekV2MLAAttention)
|
||||
@@ -79,7 +79,6 @@ from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.utils import dispose_tensor
|
||||
|
||||
VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
|
||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
|
||||
|
||||
class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
|
||||
@@ -189,26 +188,8 @@ class CustomDeepseekDBOMoE(nn.Module):
|
||||
if hasattr(attn_metadata, 'with_prefill_across_dp'):
|
||||
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
||||
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
|
||||
old_hidden_states = hidden_states.clone()
|
||||
|
||||
if self.tp_size > 1:
|
||||
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
|
||||
chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
|
||||
hidden_states = chunks[self.tp_rank]
|
||||
elif not self.torchair_graph_enabled:
|
||||
num_padding_tokens = (self.tp_size -
|
||||
num_tokens % self.tp_size) % self.tp_size
|
||||
# Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
|
||||
if num_padding_tokens > 0:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, num_padding_tokens))
|
||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
hidden_states = chunk_hidden_states[self.tp_rank]
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
@@ -220,33 +201,13 @@ class CustomDeepseekDBOMoE(nn.Module):
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
) * self.routed_scaling_factor
|
||||
|
||||
if self.tp_size > 1:
|
||||
if self.torchair_graph_enabled:
|
||||
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
|
||||
final_hidden_states = torch.zeros(
|
||||
[num_tokens, hidden_size],
|
||||
dtype=self.params_dtype,
|
||||
device="npu")
|
||||
dist.all_gather_into_tensor(final_hidden_states,
|
||||
hidden_states, self.tp_group)
|
||||
hidden_states = final_hidden_states
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(
|
||||
hidden_states)
|
||||
else:
|
||||
dist.all_gather(list(chunk_hidden_states), hidden_states,
|
||||
self.tp_group)
|
||||
hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
if num_padding_tokens > 0:
|
||||
hidden_states = hidden_states[:-num_padding_tokens]
|
||||
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(old_hidden_states)
|
||||
|
||||
if shared_output is not None:
|
||||
hidden_states = hidden_states + shared_output
|
||||
|
||||
return hidden_states.view(num_tokens, hidden_size)
|
||||
return hidden_states
|
||||
|
||||
# ----------------------------------------- TBO-related --------------------------------------------
|
||||
def _forward_ms_op_shared_expert(
|
||||
|
||||
Reference in New Issue
Block a user