[bugfix] fix deeepseek accuracy (#1118)

### What this PR does / why we need it?
fix deeepseek accuracy in mix-parallel case.


Signed-off-by: zzzzwwjj <1183291235@qq.com>
This commit is contained in:
zzzzwwjj
2025-06-07 21:11:36 +08:00
committed by GitHub
parent c8742146d3
commit f1543d5e0d
3 changed files with 23 additions and 27 deletions

View File

@@ -67,6 +67,7 @@ from vllm.sequence import IntermediateTensors
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.parallel_state import get_ep_group
from vllm_ascend.ops.fused_moe import AscendFusedMoE
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
from vllm_ascend.utils import dispose_tensor
@@ -211,13 +212,15 @@ class CustomDeepseekV2MoE(nn.Module):
self.tp_group = get_tp_group().device_group
self.tp_rank = get_tp_group().rank_in_group
self.ep_group = get_ep_group()
self.params_dtype = torch.get_default_dtype()
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
# NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
self.enable_multistream_shared_expert = \
ascend_config.torchair_graph_config.enable_multistream_shared_expert
ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2
def forward(
self,
@@ -245,16 +248,12 @@ class CustomDeepseekV2MoE(nn.Module):
old_hidden_states = hidden_states.clone()
if self.tp_size > 1:
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
hidden_states = chunks[self.tp_rank]
elif not self.torchair_graph_enabled:
num_padding_tokens = (self.tp_size -
num_tokens % self.tp_size) % self.tp_size
# Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
if num_padding_tokens > 0:
if (VLLM_ENABLE_MC2
and not is_prefill) or not (self.torchair_graph_enabled or
self.ep_group.world_size == 1):
if num_tokens < self.tp_size:
hidden_states = nn.functional.pad(
hidden_states, (0, 0, 0, num_padding_tokens))
hidden_states, (0, 0, 0, self.tp_size - num_tokens))
chunk_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
@@ -284,24 +283,16 @@ class CustomDeepseekV2MoE(nn.Module):
hidden_states = hidden_states * self.routed_scaling_factor
if self.tp_size > 1:
if self.torchair_graph_enabled:
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
final_hidden_states = torch.zeros(
[num_tokens, hidden_size],
dtype=self.params_dtype,
device="npu")
dist.all_gather_into_tensor(final_hidden_states,
hidden_states, self.tp_group)
hidden_states = final_hidden_states
else:
hidden_states = tensor_model_parallel_all_reduce(
hidden_states)
else:
if (VLLM_ENABLE_MC2
and not is_prefill) or not (self.torchair_graph_enabled or
self.ep_group.world_size == 1):
dist.all_gather(list(chunk_hidden_states), hidden_states,
self.tp_group)
hidden_states = torch.cat(chunk_hidden_states, dim=0)
if num_padding_tokens > 0:
hidden_states = hidden_states[:-num_padding_tokens]
if num_tokens < self.tp_size:
hidden_states = hidden_states[:num_tokens]
else:
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
if self.n_shared_experts is not None:
if not multistream:

View File

@@ -1027,8 +1027,9 @@ class AscendFusedMoE(FusedMoE):
ascend_config = get_ascend_config()
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
# NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
self.enable_multistream_shared_expert = \
ascend_config.torchair_graph_config.enable_multistream_shared_expert
ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2
if self.scoring_func != "softmax" and not self.use_grouped_topk:
raise ValueError("Only softmax scoring function is supported for "

View File

@@ -142,7 +142,11 @@ class NPUPlatform(Platform):
# NOTE: When enable_expert_parallel is True, we follow vLLM convention:
# ep_size = world_size, which means expert_tensor_parallel_size must be 1
if ascend_config.expert_tensor_parallel_size > 0 and not parallel_config.enable_expert_parallel:
if parallel_config.enable_expert_parallel:
parallel_config.expert_tensor_parallel_size = 1
# NOTE: When enable_expert_parallel is False and param `asceend_config.expert_tensor_parallel_size`
# is configured, use ascend_config
elif ascend_config.expert_tensor_parallel_size > 0:
parallel_config.expert_tensor_parallel_size = ascend_config.expert_tensor_parallel_size
# Calculate expert parallel size based on world size