[BugFix]add all2all when dp_size > 1 && downgrade npu_dequant_swiglu_quant (#819)
<!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? 1. This PR introduces native `all_to_all` communication operator to fix `allgather` bugs when dp_size > 1. Besides, it adds a naive implementation of force-load-balance when doing profile runs. 2. The operator `npu_dequant_swiglu_quant` only supports input hidden_states with dtype `torch.int32`. This tensor occupies space of `global_bs * seq_len * topk * hidden_size`, which might be very large as `ep_size` grows. Therefore we need to disable this operator and use original `swiglu` && `quantize`. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? By performing offline inference:  --------- Signed-off-by: angazenn <zengyanjia@huawei.com> Co-authored-by: angazenn <zengyanjia@huawei.com>
This commit is contained in:
@@ -14,10 +14,10 @@
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
from torch.distributed import ProcessGroup
|
||||
import torch.distributed as dist
|
||||
from vllm.distributed.device_communicators.base_device_communicator import \
|
||||
DeviceCommunicatorBase
|
||||
|
||||
@@ -25,11 +25,51 @@ from vllm.distributed.device_communicators.base_device_communicator import \
|
||||
class NPUCommunicator(DeviceCommunicatorBase):
|
||||
|
||||
def __init__(self,
|
||||
cpu_group: ProcessGroup,
|
||||
cpu_group: dist.ProcessGroup,
|
||||
device: Optional[torch.device] = None,
|
||||
device_group: Optional[ProcessGroup] = None,
|
||||
device_group: Optional[dist.ProcessGroup] = None,
|
||||
unique_name: str = ""):
|
||||
super().__init__(cpu_group, device, device_group, unique_name)
|
||||
# TODO(hz): Refer to CudaCommunicator's implementation to integrate PyHcclCommunicator
|
||||
# init device according to rank
|
||||
self.device = torch.npu.current_device()
|
||||
|
||||
def all_to_all(self,
|
||||
input_: torch.Tensor,
|
||||
scatter_dim: int = 0,
|
||||
gather_dim: int = -1,
|
||||
scatter_sizes: Optional[List[int]] = None,
|
||||
gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
|
||||
|
||||
if scatter_dim < 0:
|
||||
scatter_dim += input_.dim()
|
||||
if gather_dim < 0:
|
||||
gather_dim += input_.dim()
|
||||
|
||||
if scatter_sizes is not None and gather_sizes is not None:
|
||||
input_list = [
|
||||
t.contiguous()
|
||||
for t in torch.split(input_, scatter_sizes, scatter_dim)
|
||||
]
|
||||
output_list = []
|
||||
tensor_shape_base = input_list[self.rank].size()
|
||||
for i in range(self.world_size):
|
||||
tensor_shape = list(tensor_shape_base)
|
||||
tensor_shape[gather_dim] = gather_sizes[i]
|
||||
output_list.append(
|
||||
torch.empty(tensor_shape,
|
||||
dtype=input_.dtype,
|
||||
device=input_.device))
|
||||
|
||||
else:
|
||||
input_list = [
|
||||
t.contiguous() for t in torch.tensor_split(
|
||||
input_, self.world_size, scatter_dim)
|
||||
]
|
||||
output_list = [
|
||||
torch.empty_like(input_list[i]) for i in range(self.world_size)
|
||||
]
|
||||
|
||||
dist.all_to_all(output_list, input_list, group=self.device_group)
|
||||
output_tensor = torch.cat(output_list, dim=gather_dim).contiguous()
|
||||
return output_tensor
|
||||
|
||||
@@ -205,50 +205,66 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
)
|
||||
CustomDeepseekV2MoE.top_k = config.num_experts_per_tok
|
||||
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.dp_size = get_dp_group().world_size
|
||||
batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
|
||||
params_dtype = torch.get_default_dtype()
|
||||
self.final_hidden_states = torch.zeros(
|
||||
[batch_size, config.hidden_size], dtype=params_dtype, device="npu")
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
attn_metadata = get_forward_context().attn_metadata
|
||||
# when profile runs, force experts to load balanced tokens
|
||||
# to avoid high memory consumption on a single rank.
|
||||
# TODO: need a better flag to indicate whether in profile run or not.
|
||||
if attn_metadata is None:
|
||||
# for profile run
|
||||
is_prefill = True
|
||||
enable_force_load_balance = True
|
||||
else:
|
||||
is_prefill = attn_metadata.num_prefills > 0
|
||||
enable_force_load_balance = False
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
hidden_states = hidden_states.view(-1, hidden_dim)
|
||||
|
||||
if (self.tp_size > 1 and VLLM_ENABLE_MC2 and not is_prefill):
|
||||
chunks = torch.chunk(hidden_states,
|
||||
get_tp_group().world_size,
|
||||
dim=0)
|
||||
hidden_states = chunks[get_tp_group().rank_in_group]
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
final_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=CustomDeepseekV2MoE.top_k) * self.routed_scaling_factor
|
||||
|
||||
if self.tp_size > 1:
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
dist.all_gather_into_tensor(self.final_hidden_states,
|
||||
final_hidden_states, self.tp_group)
|
||||
final_hidden_states = self.final_hidden_states
|
||||
else:
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
final_hidden_states)
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
|
||||
if self.tp_size > 1:
|
||||
# pass
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
if num_tokens < self.tp_size:
|
||||
target_size = self.tp_size
|
||||
new_hidden_states = torch.empty([target_size, hidden_size],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
new_hidden_states[:num_tokens] = hidden_states
|
||||
hidden_states = new_hidden_states
|
||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
local_hidden_states = chunk_hidden_states[self.tp_rank]
|
||||
else:
|
||||
local_hidden_states = hidden_states
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(local_hidden_states)
|
||||
|
||||
router_hidden_states = self.experts(
|
||||
hidden_states=local_hidden_states,
|
||||
router_logits=router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=CustomDeepseekV2MoE.top_k,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
) * self.routed_scaling_factor
|
||||
|
||||
if self.tp_size > 1:
|
||||
dist.all_gather(list(chunk_hidden_states), router_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
if num_tokens < self.tp_size:
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
else:
|
||||
final_hidden_states = router_hidden_states
|
||||
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
|
||||
@@ -18,7 +18,6 @@
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import tensor_model_parallel_all_reduce
|
||||
@@ -636,6 +635,7 @@ class AscendFusedMoE(FusedMoE):
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
is_prefill: bool,
|
||||
enable_force_load_balance: bool = False,
|
||||
top_k=None):
|
||||
assert self.quant_method is not None
|
||||
|
||||
@@ -644,17 +644,8 @@ class AscendFusedMoE(FusedMoE):
|
||||
else:
|
||||
real_top_k = self.top_k
|
||||
|
||||
if self.dp_size > 1:
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
...
|
||||
elif USING_LCCL_COM: # type: ignore
|
||||
hidden_states = get_dp_group().all_gather(
|
||||
hidden_states, 0, False)
|
||||
router_logits = get_dp_group().all_gather(
|
||||
router_logits, 0, False)
|
||||
else:
|
||||
hidden_states = get_dp_group().all_gather(hidden_states, 0)
|
||||
router_logits = get_dp_group().all_gather(router_logits, 0)
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
...
|
||||
|
||||
# Matrix multiply.
|
||||
final_hidden_states = self.quant_method.apply(
|
||||
@@ -671,17 +662,12 @@ class AscendFusedMoE(FusedMoE):
|
||||
custom_routing_function=self.custom_routing_function,
|
||||
scoring_func=self.scoring_func,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
is_prefill=is_prefill)
|
||||
is_prefill=is_prefill,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
dp_size=self.dp_size)
|
||||
|
||||
if self.dp_size > 1:
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
...
|
||||
else:
|
||||
final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
|
||||
final_hidden_states,
|
||||
"sum",
|
||||
scatter_dim=0,
|
||||
group=get_dp_group().device_group)
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
...
|
||||
|
||||
if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
|
||||
final_hidden_states = tensor_model_parallel_all_reduce(
|
||||
|
||||
@@ -14,9 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
# patch_utils should be the first import, because it will be used by other
|
||||
# patch files.
|
||||
import vllm_ascend.patch.worker.patch_common.patch_utils # noqa isort:skip
|
||||
import vllm_ascend.patch.worker.patch_common.patch_distributed # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_metrics # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_minicpm # noqa
|
||||
import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker # noqa
|
||||
|
||||
49
vllm_ascend/patch/worker/patch_common/patch_distributed.py
Normal file
49
vllm_ascend/patch/worker/patch_common/patch_distributed.py
Normal file
@@ -0,0 +1,49 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import vllm
|
||||
from vllm.distributed.parallel_state import GroupCoordinator
|
||||
|
||||
|
||||
class GroupCoordinatorPatch(GroupCoordinator):
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def all_to_all(self,
|
||||
input_: torch.Tensor,
|
||||
scatter_dim: int = 0,
|
||||
gather_dim: int = -1,
|
||||
scatter_sizes: Optional[List[int]] = None,
|
||||
gather_sizes: Optional[List[int]] = None) -> torch.Tensor:
|
||||
if self.world_size == 1:
|
||||
return input_
|
||||
assert -input_.dim() <= scatter_dim < input_.dim(), (
|
||||
f"Invalid scatter dim ({scatter_dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
assert -input_.dim() <= gather_dim < input_.dim(), (
|
||||
f"Invalid gather dim ({gather_dim}) for input tensor with shape {input_.size()}"
|
||||
)
|
||||
return self.device_communicator.all_to_all(input_, scatter_dim,
|
||||
gather_dim, scatter_sizes,
|
||||
gather_sizes)
|
||||
|
||||
|
||||
vllm.distributed.parallel_state.GroupCoordinator = GroupCoordinatorPatch # Note: check the GroupCoordinator with online serving
|
||||
@@ -321,14 +321,15 @@ class AscendFusedMoEMethod(FusedMoEMethodBase):
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = False,
|
||||
dp_size: int = 1,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
return self.quant_method.apply(layer, x, router_logits, top_k,
|
||||
renormalize, use_grouped_topk,
|
||||
global_num_experts, expert_map,
|
||||
topk_group, num_expert_group,
|
||||
custom_routing_function, scoring_func,
|
||||
e_score_correction_bias, is_prefill)
|
||||
return self.quant_method.apply(
|
||||
layer, x, router_logits, top_k, renormalize, use_grouped_topk,
|
||||
global_num_experts, expert_map, topk_group, num_expert_group,
|
||||
custom_routing_function, scoring_func, e_score_correction_bias,
|
||||
is_prefill, enable_force_load_balance, dp_size)
|
||||
|
||||
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
||||
if hasattr(self.quant_method, "process_weights_after_loading"):
|
||||
|
||||
@@ -15,15 +15,19 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import os
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
from vllm.distributed import GroupCoordinator
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.ops.fused_moe import select_experts
|
||||
|
||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
|
||||
|
||||
def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||
w1: torch.Tensor,
|
||||
@@ -68,24 +72,18 @@ def apply_mlp(hidden_states_wrapper: List[torch.Tensor],
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
weight=[w1],
|
||||
split_item=3,
|
||||
scale=[w1_scale],
|
||||
per_token_scale=[pertoken_scale],
|
||||
split_item=2,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
|
||||
# act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=hidden_states,
|
||||
weight_scale=w1_scale,
|
||||
activation_scale=pertoken_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=group_list,
|
||||
activate_left=True,
|
||||
quant_mode=1,
|
||||
)
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
@@ -201,6 +199,132 @@ def fused_experts_with_mc2(
|
||||
return hidden_states
|
||||
|
||||
|
||||
# currently expert parallelism implemented with all2all
|
||||
# is under-optimized.
|
||||
def fused_experts_with_all2all(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
ep_group: GroupCoordinator = None,
|
||||
):
|
||||
original_shape = hidden_states.shape
|
||||
if len(original_shape) == 3:
|
||||
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
||||
|
||||
num_tokens, _ = hidden_states.shape
|
||||
num_experts = w1.shape[0]
|
||||
device = hidden_states.device
|
||||
|
||||
if expert_map is not None:
|
||||
global_num_experts = len(expert_map)
|
||||
local_num_experts = global_num_experts // ep_group.world_size
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = (torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=device).view(top_k, -1).permute(
|
||||
1, 0).contiguous())
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=num_tokens)
|
||||
|
||||
global_expert_tokens = torch.bincount(expanded_expert_idx,
|
||||
minlength=global_num_experts)
|
||||
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
|
||||
-1).sum(-1)
|
||||
|
||||
gather_sizes = torch.empty_like(scatter_sizes)
|
||||
dist.all_to_all_single(gather_sizes,
|
||||
scatter_sizes,
|
||||
group=ep_group.device_group)
|
||||
scatter_size_list = scatter_sizes.cpu().tolist()
|
||||
gather_size_list = gather_sizes.cpu().tolist()
|
||||
|
||||
expanded_expert_idx = expanded_expert_idx % local_num_experts
|
||||
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
||||
scatter_size_list,
|
||||
gather_size_list)
|
||||
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
|
||||
scatter_size_list,
|
||||
gather_size_list)
|
||||
|
||||
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
||||
|
||||
hidden_states = hidden_states[sorted_idx]
|
||||
group_list_type = 0
|
||||
else:
|
||||
row_idx_len = num_tokens * top_k
|
||||
row_idx = torch.arange(0,
|
||||
row_idx_len,
|
||||
dtype=torch.int32,
|
||||
device=topk_weights.device).view(
|
||||
top_k, -1).permute(1, 0).contiguous()
|
||||
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
||||
hidden_states,
|
||||
row_idx=row_idx,
|
||||
expert_idx=topk_ids,
|
||||
active_num=num_tokens)
|
||||
|
||||
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
||||
expanded_expert_idx, num_experts)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 0
|
||||
|
||||
hidden_states_wrapper = [hidden_states]
|
||||
del hidden_states
|
||||
|
||||
hidden_states = apply_mlp(hidden_states_wrapper,
|
||||
w1,
|
||||
w1_scale,
|
||||
w2,
|
||||
w2_scale,
|
||||
expert_tokens,
|
||||
group_list_type=group_list_type)
|
||||
|
||||
if expert_map is not None:
|
||||
resorted_idx = torch.argsort(sorted_idx)
|
||||
hidden_states = hidden_states[resorted_idx]
|
||||
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
||||
gather_size_list,
|
||||
scatter_size_list)
|
||||
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=topk_weights,
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids,
|
||||
)
|
||||
else:
|
||||
# TODO: Reorder device memory 2 times here, replace the current
|
||||
# implementation here when suitable operators become available.
|
||||
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
||||
hidden_states,
|
||||
skip1=None,
|
||||
skip2=None,
|
||||
bias=None,
|
||||
scales=topk_weights,
|
||||
expanded_src_to_dst_row=expanded_row_idx,
|
||||
export_for_source_row=topk_ids,
|
||||
)
|
||||
if len(original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(original_shape)
|
||||
return final_hidden_states
|
||||
|
||||
|
||||
def fused_experts(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
@@ -387,10 +511,10 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
def __init__(self):
|
||||
self.transpose_weight = True
|
||||
|
||||
ep_group = get_ep_group()
|
||||
self.ep_group = get_ep_group()
|
||||
|
||||
try:
|
||||
device_group = ep_group.device_group
|
||||
device_group = self.ep_group.device_group
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=device_group)
|
||||
backend = device_group._get_backend(torch.device("npu"))
|
||||
@@ -457,6 +581,8 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
is_prefill: bool = True,
|
||||
enable_force_load_balance: bool = True,
|
||||
dp_size: int = 1,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
@@ -491,7 +617,13 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if os.environ.get("VLLM_ENABLE_MC2", '0') == "1" and not is_prefill:
|
||||
# this is a naive implementation for experts load balance so as
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
||||
|
||||
if VLLM_ENABLE_MC2 and not is_prefill:
|
||||
return fused_experts_with_mc2(
|
||||
hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
@@ -503,7 +635,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name)
|
||||
else:
|
||||
elif dp_size == 1:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
@@ -513,6 +645,17 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map)
|
||||
else:
|
||||
return fused_experts_with_all2all(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
w1_scale=layer.w13_weight_scale,
|
||||
w2=layer.w2_weight,
|
||||
w2_scale=layer.w2_weight_scale,
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
top_k=top_k,
|
||||
expert_map=expert_map,
|
||||
ep_group=self.ep_group)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.transpose_weight:
|
||||
@@ -521,7 +664,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
||||
1, 2).contiguous()
|
||||
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
||||
layer.w13_weight_scale.data.shape[0], -1).to(torch.float32)
|
||||
layer.w13_weight_scale.data.shape[0], -1)
|
||||
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
||||
layer.w13_weight_offset.data.shape[0], -1)
|
||||
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
||||
|
||||
Reference in New Issue
Block a user