[BugFix] Fix accuracy bugs for unquantized deepseekv3 models (#897)
### What this PR does / why we need it? This PR fixes two accuracy bugs incurred by PR #819 when running deepseekv3 series models: 1. #819 adds `all_to_all` communication in quantized cases, but `all_gather` && `reduce_scatter` are removed in both of quantized and unquantized cases. When running unquantized deepseekv3 models with `ep_size == world_size`, the moe modules fail to communicate. Therefore, this PR adds `all_to_all` communication on unquantized situation to solve this accuracy issue. 2. Use `ep_size` rather than `dp_size` to decide whether to use `all_to_all` in moe. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? CI passed with new added/existing test. --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
@@ -18,9 +18,11 @@
|
|||||||
from typing import Callable, Optional
|
from typing import Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import torch_npu
|
import torch_npu
|
||||||
from vllm.config import get_current_vllm_config
|
from vllm.config import get_current_vllm_config
|
||||||
from vllm.distributed import (get_tensor_model_parallel_world_size,
|
from vllm.distributed import (GroupCoordinator,
|
||||||
|
get_tensor_model_parallel_world_size,
|
||||||
tensor_model_parallel_all_reduce)
|
tensor_model_parallel_all_reduce)
|
||||||
from vllm.distributed.parallel_state import get_dp_group
|
from vllm.distributed.parallel_state import get_dp_group
|
||||||
from vllm.model_executor.layers.fused_moe.layer import (
|
from vllm.model_executor.layers.fused_moe.layer import (
|
||||||
@@ -154,6 +156,143 @@ def fused_experts_with_mc2(
|
|||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
|
# currently expert parallelism implemented with all2all
|
||||||
|
# is under-optimized.
|
||||||
|
def fused_experts_with_all2all(
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
w1: torch.Tensor,
|
||||||
|
w2: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
top_k: int,
|
||||||
|
expert_map: torch.Tensor = None,
|
||||||
|
ep_group: GroupCoordinator = None,
|
||||||
|
):
|
||||||
|
original_shape = hidden_states.shape
|
||||||
|
if len(original_shape) == 3:
|
||||||
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||||
|
|
||||||
|
num_tokens, _ = hidden_states.shape
|
||||||
|
num_experts = w1.shape[0]
|
||||||
|
device = hidden_states.device
|
||||||
|
|
||||||
|
if expert_map is not None:
|
||||||
|
global_num_experts = len(expert_map)
|
||||||
|
local_num_experts = global_num_experts // ep_group.world_size
|
||||||
|
row_idx_len = num_tokens * top_k
|
||||||
|
row_idx = (torch.arange(0,
|
||||||
|
row_idx_len,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=device).view(top_k, -1).permute(
|
||||||
|
1, 0).contiguous())
|
||||||
|
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||||
|
hidden_states,
|
||||||
|
row_idx=row_idx,
|
||||||
|
expert_idx=topk_ids,
|
||||||
|
active_num=num_tokens)
|
||||||
|
|
||||||
|
global_expert_tokens = torch.bincount(expanded_expert_idx,
|
||||||
|
minlength=global_num_experts)
|
||||||
|
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
|
||||||
|
-1).sum(-1)
|
||||||
|
|
||||||
|
gather_sizes = torch.empty_like(scatter_sizes)
|
||||||
|
dist.all_to_all_single(gather_sizes,
|
||||||
|
scatter_sizes,
|
||||||
|
group=ep_group.device_group)
|
||||||
|
scatter_size_list = scatter_sizes.cpu().tolist()
|
||||||
|
gather_size_list = gather_sizes.cpu().tolist()
|
||||||
|
|
||||||
|
expanded_expert_idx = expanded_expert_idx % local_num_experts
|
||||||
|
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
||||||
|
scatter_size_list,
|
||||||
|
gather_size_list)
|
||||||
|
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
|
||||||
|
scatter_size_list,
|
||||||
|
gather_size_list)
|
||||||
|
|
||||||
|
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
|
||||||
|
|
||||||
|
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||||
|
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
||||||
|
|
||||||
|
hidden_states = hidden_states[sorted_idx]
|
||||||
|
else:
|
||||||
|
row_idx_len = num_tokens * top_k
|
||||||
|
row_idx = torch.arange(0,
|
||||||
|
row_idx_len,
|
||||||
|
dtype=torch.int32,
|
||||||
|
device=topk_weights.device).view(
|
||||||
|
top_k, -1).permute(1, 0).contiguous()
|
||||||
|
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||||
|
hidden_states,
|
||||||
|
row_idx=row_idx,
|
||||||
|
expert_idx=topk_ids,
|
||||||
|
active_num=num_tokens)
|
||||||
|
|
||||||
|
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||||
|
expanded_expert_idx, num_experts)
|
||||||
|
expert_tokens = expert_tokens.to(torch.int64)
|
||||||
|
|
||||||
|
w1 = w1.transpose(1, 2)
|
||||||
|
gate_up_out_list = torch_npu.npu_grouped_matmul(
|
||||||
|
x=[hidden_states],
|
||||||
|
weight=[w1],
|
||||||
|
split_item=2,
|
||||||
|
group_list_type=0,
|
||||||
|
group_type=0,
|
||||||
|
group_list=expert_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO: Remove this in the future.
|
||||||
|
hidden_states = torch.cat(gate_up_out_list, dim=0)
|
||||||
|
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||||
|
|
||||||
|
w2 = w2.transpose(1, 2)
|
||||||
|
down_out_list = torch_npu.npu_grouped_matmul(
|
||||||
|
x=[hidden_states],
|
||||||
|
weight=[w2],
|
||||||
|
split_item=2,
|
||||||
|
group_list_type=0,
|
||||||
|
group_type=0,
|
||||||
|
group_list=expert_tokens,
|
||||||
|
)
|
||||||
|
|
||||||
|
hidden_states = torch.cat(down_out_list, dim=0)
|
||||||
|
|
||||||
|
if expert_map is not None:
|
||||||
|
resorted_idx = torch.argsort(sorted_idx)
|
||||||
|
hidden_states = hidden_states[resorted_idx]
|
||||||
|
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
||||||
|
gather_size_list,
|
||||||
|
scatter_size_list)
|
||||||
|
|
||||||
|
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||||
|
hidden_states,
|
||||||
|
skip1=None,
|
||||||
|
skip2=None,
|
||||||
|
bias=None,
|
||||||
|
scales=topk_weights,
|
||||||
|
expanded_src_to_dst_row=expanded_row_idx,
|
||||||
|
export_for_source_row=topk_ids,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# TODO: Reorder device memory 2 times here, replace the current
|
||||||
|
# implementation here when suitable operators become available.
|
||||||
|
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||||
|
hidden_states,
|
||||||
|
skip1=None,
|
||||||
|
skip2=None,
|
||||||
|
bias=None,
|
||||||
|
scales=topk_weights,
|
||||||
|
expanded_src_to_dst_row=expanded_row_idx,
|
||||||
|
export_for_source_row=topk_ids,
|
||||||
|
)
|
||||||
|
if len(original_shape) == 3:
|
||||||
|
final_hidden_states = final_hidden_states.view(original_shape)
|
||||||
|
return final_hidden_states
|
||||||
|
|
||||||
|
|
||||||
def fused_experts(
|
def fused_experts(
|
||||||
hidden_states: torch.Tensor,
|
hidden_states: torch.Tensor,
|
||||||
w1: torch.Tensor,
|
w1: torch.Tensor,
|
||||||
@@ -494,7 +633,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
custom_routing_function: Optional[Callable] = None,
|
custom_routing_function: Optional[Callable] = None,
|
||||||
scoring_func: str = "softmax",
|
scoring_func: str = "softmax",
|
||||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
is_prefill=False,
|
is_prefill: bool = False,
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
||||||
@@ -536,7 +675,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
||||||
else:
|
elif get_ep_group().world_size == 1:
|
||||||
return fused_experts(hidden_states=x,
|
return fused_experts(hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w2=layer.w2_weight,
|
w2=layer.w2_weight,
|
||||||
@@ -544,6 +683,19 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
expert_map=expert_map)
|
expert_map=expert_map)
|
||||||
|
else:
|
||||||
|
# The current implementation of deepseek moe splits hidden_states
|
||||||
|
# according to tp_size before they are feed into fused_moe module.
|
||||||
|
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
||||||
|
# dispatch/combine tokens.
|
||||||
|
return fused_experts_with_all2all(hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
top_k=top_k,
|
||||||
|
expert_map=expert_map,
|
||||||
|
ep_group=get_ep_group())
|
||||||
|
|
||||||
|
|
||||||
class AscendFusedMoE(FusedMoE):
|
class AscendFusedMoE(FusedMoE):
|
||||||
@@ -721,8 +873,7 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
scoring_func=self.scoring_func,
|
scoring_func=self.scoring_func,
|
||||||
e_score_correction_bias=self.e_score_correction_bias,
|
e_score_correction_bias=self.e_score_correction_bias,
|
||||||
is_prefill=is_prefill,
|
is_prefill=is_prefill,
|
||||||
enable_force_load_balance=enable_force_load_balance,
|
enable_force_load_balance=enable_force_load_balance)
|
||||||
dp_size=self.dp_size)
|
|
||||||
|
|
||||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||||
...
|
...
|
||||||
|
|||||||
@@ -323,14 +323,13 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
|||||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
is_prefill: bool = True,
|
is_prefill: bool = True,
|
||||||
enable_force_load_balance: bool = False,
|
enable_force_load_balance: bool = False,
|
||||||
dp_size: int = 1,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return self.quant_method.apply(
|
return self.quant_method.apply(
|
||||||
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
|
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
|
||||||
global_num_experts, expert_map, topk_group, num_expert_group,
|
global_num_experts, expert_map, topk_group, num_expert_group,
|
||||||
custom_routing_function, scoring_func, e_score_correction_bias,
|
custom_routing_function, scoring_func, e_score_correction_bias,
|
||||||
is_prefill, enable_force_load_balance, dp_size)
|
is_prefill, enable_force_load_balance)
|
||||||
|
|
||||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||||
|
|||||||
@@ -582,7 +582,6 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||||
is_prefill: bool = True,
|
is_prefill: bool = True,
|
||||||
enable_force_load_balance: bool = True,
|
enable_force_load_balance: bool = True,
|
||||||
dp_size: int = 1,
|
|
||||||
**kwargs,
|
**kwargs,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
assert router_logits.shape[
|
assert router_logits.shape[
|
||||||
@@ -635,7 +634,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
||||||
elif dp_size == 1:
|
elif self.ep_group.world_size == 1:
|
||||||
return fused_experts(hidden_states=x,
|
return fused_experts(hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w1_scale=layer.w13_weight_scale,
|
w1_scale=layer.w13_weight_scale,
|
||||||
@@ -646,6 +645,10 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
top_k=top_k,
|
top_k=top_k,
|
||||||
expert_map=expert_map)
|
expert_map=expert_map)
|
||||||
else:
|
else:
|
||||||
|
# The current implementation of deepseek moe splits hidden_states
|
||||||
|
# according to tp_size before they are feed into fused_moe module.
|
||||||
|
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
||||||
|
# dispatch/combine tokens.
|
||||||
return fused_experts_with_all2all(hidden_states=x,
|
return fused_experts_with_all2all(hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
w1_scale=layer.w13_weight_scale,
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
|||||||
Reference in New Issue
Block a user