### What this PR does / why we need it?
Add EPLB expert map import capabilities
### Does this PR introduce _any_ user-facing change?
When importing the EPLB expert map you need import expert map file by
vllm args additional_config
### How was this patch tested?
1.You need to collect expert hotness and generate an expert placement
file based on the hotness and the EPLB algorithm, or you can directly
use an existing expert placement table.
2.When launching vLLM, enable EC2 and pass the configuration via the
command-line argument:
--additional-config '{"expert_map_path": "/xxx/xxx/xx.json"}
Co-authored-by: songshanhu07 <1763685535@qq.com>
---------
Signed-off-by: songshanhu07 <1763685535@qq.com>
Signed-off-by: Yuxiao-Xu <664988918@qq.com>
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: songshanhu07 <1763685535@qq.com>
Co-authored-by: Xu Yuxiao <xuyuxiao2@huawei.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
753 lines
29 KiB
Python
753 lines
29 KiB
Python
#
|
|
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
# This file is a part of the vllm-ascend project.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
#
|
|
|
|
from typing import Any, Callable, Dict, Optional
|
|
|
|
import torch
|
|
import torch.distributed as dist
|
|
import torch_npu
|
|
import torchair as tng # type: ignore
|
|
from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
|
|
|
|
import vllm_ascend.envs as envs_ascend
|
|
from vllm_ascend.ascend_config import get_ascend_config
|
|
from vllm_ascend.distributed.parallel_state import get_ep_group
|
|
from vllm_ascend.ops.fused_moe import select_experts
|
|
from vllm_ascend.utils import dispose_tensor
|
|
|
|
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
|
|
|
|
|
def apply_mlp(hidden_states: torch.Tensor,
|
|
w1: torch.Tensor,
|
|
w1_scale: torch.Tensor,
|
|
w2: torch.Tensor,
|
|
w2_scale: torch.Tensor,
|
|
group_list: torch.Tensor,
|
|
dynamic_scale: torch.Tensor = None,
|
|
group_list_type: int = 1,
|
|
**kwargs) -> torch.Tensor:
|
|
"""
|
|
apply MLP: gate_up_proj -> swiglu -> down_proj
|
|
|
|
Args:
|
|
hidden_states: input hidden states with shape (num_tokens, hidden_size).
|
|
w1: expert weights1 with shape
|
|
(num_experts, hidden_size, intermediate_size * 2)
|
|
w1_scale: weights1 scale with shape (num_experts, intermediate_size * 2)
|
|
w2: expert weights2 with shape
|
|
(num_experts, intermediate_size, hidden_size)
|
|
w2_scale: weights2 scale with shape (num_experts, hidden_size)
|
|
group_list: number of tokens for each expert, follow cumsum mode, and
|
|
with shape (num_experts).
|
|
transpose_weight:
|
|
w1: (num_experts, intermediate_size * 2, hidden_size) ->
|
|
(num_experts, hidden_size, intermediate_size * 2)
|
|
w2: (num_experts, hidden_size, intermediate_size) ->
|
|
(num_experts, intermediate_size, hidden_size)
|
|
|
|
Returns:
|
|
hidden_states: output hidden states after MLP.
|
|
"""
|
|
|
|
if dynamic_scale is None:
|
|
unquantized_hidden_states = hidden_states
|
|
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
|
hidden_states)
|
|
# Dispose the original unquantized hidden states
|
|
# to save npu memory because they're no longer used.
|
|
dispose_tensor(unquantized_hidden_states)
|
|
else:
|
|
pertoken_scale = dynamic_scale
|
|
|
|
shared_experts = kwargs.get('shared_experts', None)
|
|
if shared_experts:
|
|
shared_gate_up = kwargs.get('shared_gate_up', None)
|
|
shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
|
|
with tng.scope.npu_stream_switch('cv'):
|
|
tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
|
|
shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
|
|
x=shared_gate_up,
|
|
weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
|
|
activation_scale=shared_dynamic_scale,
|
|
bias=None,
|
|
quant_scale=None,
|
|
quant_offset=None,
|
|
group_index=None,
|
|
activate_left=True,
|
|
quant_mode=1)
|
|
|
|
# gmm1: gate_up_proj
|
|
hidden_states = torch_npu.npu_grouped_matmul(
|
|
x=[hidden_states],
|
|
weight=[w1],
|
|
scale=[w1_scale],
|
|
per_token_scale=[pertoken_scale],
|
|
split_item=2,
|
|
group_list_type=group_list_type,
|
|
group_type=0,
|
|
group_list=group_list,
|
|
output_dtype=w2_scale.dtype)[0]
|
|
|
|
# act_fn: swiglu
|
|
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
|
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
|
hidden_states)
|
|
|
|
# gmm2: down_proj
|
|
hidden_states = torch_npu.npu_grouped_matmul(
|
|
x=[hidden_states],
|
|
weight=[w2],
|
|
scale=[w2_scale],
|
|
per_token_scale=[swiglu_out_scale],
|
|
split_item=2,
|
|
group_list_type=group_list_type,
|
|
group_type=0,
|
|
group_list=group_list,
|
|
output_dtype=w2_scale.dtype)[0]
|
|
|
|
if shared_experts:
|
|
with tng.scope.npu_stream_switch('cv'):
|
|
tng.scope.npu_wait_tensor(shared_x, hidden_states)
|
|
shared_output = torch_npu.npu_quant_matmul(
|
|
shared_x,
|
|
shared_experts.down_proj.weight,
|
|
shared_experts.down_proj.weight_scale,
|
|
pertoken_scale=shared_dynamic_scale,
|
|
output_dtype=torch.bfloat16,
|
|
)
|
|
if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
|
|
shared_output = tensor_model_parallel_all_reduce(shared_output)
|
|
if shared_experts:
|
|
return hidden_states, shared_output
|
|
return hidden_states
|
|
|
|
|
|
def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
|
w1: torch.Tensor,
|
|
w2: torch.Tensor,
|
|
w1_scale: torch.Tensor,
|
|
w2_scale: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
top_k: int,
|
|
expert_map: torch.Tensor = None,
|
|
moe_all_to_all_group_name: str = "",
|
|
log2phy: torch.Tensor = None,
|
|
global_redundant_expert_num: int = 0,
|
|
**kwargs) -> torch.Tensor:
|
|
|
|
topk_ids = log2phy[topk_ids]
|
|
global_bs = 0
|
|
moe_expert_num = len(expert_map) + global_redundant_expert_num
|
|
# hidden_states = hidden_states.bfloat16()
|
|
kwargs_mc2 = {
|
|
"x": hidden_states,
|
|
"expert_ids": topk_ids,
|
|
"expert_shard_type": 0,
|
|
"shared_expert_rank_num": 0,
|
|
"moe_expert_num": moe_expert_num,
|
|
"global_bs": global_bs,
|
|
}
|
|
|
|
rank = torch.distributed.get_rank()
|
|
|
|
quant_mode = 2
|
|
ep_group = get_ep_group().device_group
|
|
local_rank = torch.distributed.get_rank(group=ep_group)
|
|
all_to_all_group_size = torch.distributed.get_world_size(ep_group)
|
|
|
|
world_szie = torch.distributed.get_world_size()
|
|
tp_size = world_szie // all_to_all_group_size
|
|
tp_rank = rank % tp_size
|
|
|
|
stage1_kwargs = {
|
|
"scales": None,
|
|
"quant_mode": quant_mode,
|
|
"group_ep": moe_all_to_all_group_name,
|
|
"ep_world_size": all_to_all_group_size,
|
|
"ep_rank_id": local_rank,
|
|
# "group_tp": self.moe_rs_group_name,
|
|
"group_tp": moe_all_to_all_group_name,
|
|
"tp_world_size": tp_size,
|
|
"tp_rank_id": tp_rank,
|
|
}
|
|
kwargs_mc2.update(stage1_kwargs)
|
|
|
|
shared_experts = kwargs.get('shared_experts', None)
|
|
if shared_experts:
|
|
shared_hidden_states = kwargs.get('shared_hidden_states', None)
|
|
with tng.scope.npu_stream_switch('cv'):
|
|
tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
|
|
shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
|
|
shared_hidden_states)
|
|
shared_gate_up = torch_npu.npu_quant_matmul(
|
|
shared_x,
|
|
shared_experts.gate_up_proj.weight,
|
|
shared_experts.gate_up_proj.weight_scale,
|
|
output_dtype=torch.int32,
|
|
)
|
|
kwargs.update({
|
|
"shared_gate_up": shared_gate_up,
|
|
"shared_dynamic_scale": shared_dynamic_scale,
|
|
})
|
|
|
|
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
|
# comm_stream.wait_stream(torch.npu.current_stream())
|
|
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
|
0:5]
|
|
|
|
if quant_mode == 0:
|
|
dynamic_scale = None
|
|
|
|
# `expand_x` will be disposed in the `apply_mlp` function
|
|
down_out_list = apply_mlp(expand_x,
|
|
w1,
|
|
w1_scale,
|
|
w2,
|
|
w2_scale,
|
|
expert_token_nums,
|
|
dynamic_scale=dynamic_scale,
|
|
**kwargs)
|
|
|
|
multi_stream = isinstance(down_out_list, tuple)
|
|
if multi_stream:
|
|
down_out_list, shared_output = down_out_list
|
|
|
|
# moeCombine
|
|
kwargs_mc2 = {
|
|
"expand_x": down_out_list,
|
|
"expert_ids": topk_ids,
|
|
"expand_idx": expand_idx,
|
|
"expert_scales": topk_weights.to(torch.float32),
|
|
"expert_shard_type": 0,
|
|
"shared_expert_rank_num": 0,
|
|
"moe_expert_num": moe_expert_num,
|
|
"global_bs": 0,
|
|
}
|
|
tp_recv_counts = torch.empty(1,
|
|
dtype=torch.int32,
|
|
device=hidden_states.device)
|
|
stage3_kwargs = {
|
|
"ep_send_counts": ep_recv_counts,
|
|
"group_ep": moe_all_to_all_group_name,
|
|
"ep_world_size": all_to_all_group_size,
|
|
"ep_rank_id": local_rank,
|
|
"tp_send_counts": tp_recv_counts,
|
|
# "group_tp": self.moe_rs_group_name,
|
|
"group_tp": moe_all_to_all_group_name,
|
|
"tp_world_size": tp_size,
|
|
"tp_rank_id": tp_rank,
|
|
}
|
|
kwargs_mc2.update(stage3_kwargs)
|
|
|
|
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
|
|
|
if multi_stream:
|
|
return hidden_states, shared_output
|
|
return hidden_states
|
|
|
|
|
|
# currently expert parallelism implemented with all2all
|
|
# is under-optimized.
|
|
def fused_experts_with_all2all(
|
|
hidden_states: torch.Tensor,
|
|
w1: torch.Tensor,
|
|
w1_scale: torch.Tensor,
|
|
w2: torch.Tensor,
|
|
w2_scale: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
top_k: int,
|
|
expert_map: torch.Tensor = None,
|
|
ep_group: GroupCoordinator = None,
|
|
log2phy: torch.Tensor = None,
|
|
global_redundant_expert_num: int = 0,
|
|
):
|
|
topk_ids = log2phy[topk_ids]
|
|
original_shape = hidden_states.shape
|
|
if len(original_shape) == 3:
|
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
|
|
|
num_tokens, _ = hidden_states.shape
|
|
num_experts = w1.shape[0]
|
|
device = hidden_states.device
|
|
|
|
if expert_map is not None:
|
|
global_num_experts = len(expert_map) + global_redundant_expert_num
|
|
local_num_experts = global_num_experts // ep_group.world_size
|
|
row_idx_len = num_tokens * top_k
|
|
row_idx = (torch.arange(0,
|
|
row_idx_len,
|
|
dtype=torch.int32,
|
|
device=device).view(top_k, -1).permute(
|
|
1, 0).contiguous())
|
|
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
|
hidden_states,
|
|
row_idx=row_idx,
|
|
expert_idx=topk_ids,
|
|
active_num=num_tokens)
|
|
|
|
global_expert_tokens = torch.bincount(expanded_expert_idx,
|
|
minlength=global_num_experts)
|
|
scatter_sizes = global_expert_tokens.view(ep_group.world_size,
|
|
-1).sum(-1)
|
|
|
|
gather_sizes = torch.empty_like(scatter_sizes)
|
|
dist.all_to_all_single(gather_sizes,
|
|
scatter_sizes,
|
|
group=ep_group.device_group)
|
|
scatter_size_list = scatter_sizes.cpu().tolist()
|
|
gather_size_list = gather_sizes.cpu().tolist()
|
|
|
|
expanded_expert_idx = expanded_expert_idx % local_num_experts
|
|
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
|
scatter_size_list,
|
|
gather_size_list)
|
|
local_expert_idx = ep_group.all_to_all(expanded_expert_idx, 0, 0,
|
|
scatter_size_list,
|
|
gather_size_list)
|
|
|
|
sorted_local_expert_idx, sorted_idx = torch.sort(local_expert_idx)
|
|
|
|
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
|
sorted_local_expert_idx, local_num_experts).to(torch.int64)
|
|
|
|
hidden_states = hidden_states[sorted_idx]
|
|
group_list_type = 0
|
|
else:
|
|
row_idx_len = num_tokens * top_k
|
|
row_idx = torch.arange(0,
|
|
row_idx_len,
|
|
dtype=torch.int32,
|
|
device=topk_weights.device).view(
|
|
top_k, -1).permute(1, 0).contiguous()
|
|
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
|
hidden_states,
|
|
row_idx=row_idx,
|
|
expert_idx=topk_ids,
|
|
active_num=num_tokens)
|
|
|
|
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
|
expanded_expert_idx, num_experts)
|
|
expert_tokens = expert_tokens.to(torch.int64)
|
|
group_list_type = 0
|
|
|
|
# `hidden_states` will be disposed in the `apply_mlp` function
|
|
hidden_states = apply_mlp(
|
|
hidden_states,
|
|
w1,
|
|
w1_scale, #17
|
|
w2,
|
|
w2_scale,
|
|
expert_tokens, #16
|
|
group_list_type=group_list_type)
|
|
|
|
if expert_map is not None:
|
|
resorted_idx = torch.argsort(sorted_idx)
|
|
hidden_states = hidden_states[resorted_idx]
|
|
hidden_states = ep_group.all_to_all(hidden_states, 0, 0,
|
|
gather_size_list,
|
|
scatter_size_list)
|
|
|
|
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
|
hidden_states,
|
|
skip1=None,
|
|
skip2=None,
|
|
bias=None,
|
|
scales=topk_weights,
|
|
expanded_src_to_dst_row=expanded_row_idx,
|
|
export_for_source_row=topk_ids,
|
|
)
|
|
else:
|
|
# TODO: Reorder device memory 2 times here, replace the current
|
|
# implementation here when suitable operators become available.
|
|
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
|
hidden_states,
|
|
skip1=None,
|
|
skip2=None,
|
|
bias=None,
|
|
scales=topk_weights,
|
|
expanded_src_to_dst_row=expanded_row_idx,
|
|
export_for_source_row=topk_ids,
|
|
)
|
|
if len(original_shape) == 3:
|
|
final_hidden_states = final_hidden_states.view(original_shape)
|
|
return final_hidden_states
|
|
|
|
|
|
def fused_experts(hidden_states: torch.Tensor,
|
|
w1: torch.Tensor,
|
|
w1_scale: torch.Tensor,
|
|
w2: torch.Tensor,
|
|
w2_scale: torch.Tensor,
|
|
topk_weights: torch.Tensor,
|
|
topk_ids: torch.Tensor,
|
|
top_k: int,
|
|
expert_map: torch.Tensor = None):
|
|
original_shape = hidden_states.shape
|
|
if len(original_shape) == 3:
|
|
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
|
|
|
|
num_tokens, _ = hidden_states.shape
|
|
num_experts = w1.shape[0]
|
|
dtype = hidden_states.dtype
|
|
device = hidden_states.device
|
|
|
|
if expert_map is not None:
|
|
# Generate token indices and flatten
|
|
token_indices = (torch.arange(num_tokens,
|
|
device=device,
|
|
dtype=torch.int64).unsqueeze(1).expand(
|
|
-1, top_k).reshape(-1))
|
|
|
|
# Flatten token-to-expert mappings and map to local experts
|
|
weights_flat = topk_weights.view(-1)
|
|
experts_flat = topk_ids.view(-1)
|
|
local_experts_flat = expert_map[experts_flat]
|
|
|
|
# Filter valid token-expert pairs
|
|
mask = local_experts_flat != -1
|
|
filtered_weights = torch.where(
|
|
mask, weights_flat, torch.zeros_like(weights_flat)).to(dtype)
|
|
filtered_experts = torch.where(
|
|
mask, local_experts_flat,
|
|
torch.full_like(local_experts_flat,
|
|
num_experts)).to(topk_ids.dtype)
|
|
|
|
# Sort by local expert IDs
|
|
sort_indices = torch.argsort(filtered_experts)
|
|
sorted_token_indices = token_indices[sort_indices]
|
|
sorted_weights = filtered_weights[sort_indices]
|
|
|
|
# Compute token counts with minlength of num_experts
|
|
# This is equivalent to but faster than:
|
|
# >>> token_counts = torch.bincount(filtered_experts, minlength=num_experts)[:-1]
|
|
token_counts = torch.zeros(num_experts + 1,
|
|
device=device,
|
|
dtype=torch.int64)
|
|
ones = torch.ones_like(filtered_experts, dtype=torch.int64)
|
|
token_counts.scatter_add_(0, filtered_experts.to(torch.int64), ones)
|
|
expert_tokens = token_counts[:num_experts]
|
|
# Rearrange hidden_states
|
|
hidden_states = hidden_states[sorted_token_indices]
|
|
group_list_type = 1
|
|
else:
|
|
row_idx_len = num_tokens * top_k
|
|
row_idx = torch.arange(0,
|
|
row_idx_len,
|
|
dtype=torch.int32,
|
|
device=topk_weights.device).view(
|
|
top_k, -1).permute(1, 0).contiguous()
|
|
hidden_states, expanded_row_idx, expanded_expert_idx = torch_npu.npu_moe_init_routing(
|
|
hidden_states,
|
|
row_idx=row_idx,
|
|
expert_idx=topk_ids,
|
|
active_num=num_tokens)
|
|
|
|
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
|
|
expanded_expert_idx, num_experts)
|
|
expert_tokens = expert_tokens.to(torch.int64)
|
|
group_list_type = 0
|
|
|
|
# `hidden_states` will be disposed in the `apply_mlp` function
|
|
hidden_states = apply_mlp(hidden_states,
|
|
w1,
|
|
w1_scale,
|
|
w2,
|
|
w2_scale,
|
|
expert_tokens,
|
|
group_list_type=group_list_type)
|
|
|
|
if expert_map is not None:
|
|
hidden_states.mul_(sorted_weights.unsqueeze(1))
|
|
final_hidden_states = torch.zeros(*original_shape,
|
|
device=device,
|
|
dtype=dtype)
|
|
|
|
num_valid_tokens = mask.sum()
|
|
valid_token_mask = torch.arange(
|
|
0, sorted_token_indices.shape[0],
|
|
device=device).unsqueeze(1) < num_valid_tokens
|
|
hidden_states = hidden_states.masked_fill_(~valid_token_mask,
|
|
0).to(dtype)
|
|
final_hidden_states.index_add_(0, sorted_token_indices, hidden_states)
|
|
else:
|
|
# TODO: Reorder device memory 2 times here, replace the current
|
|
# implementation here when suitable operators become available.
|
|
final_hidden_states = torch_npu.npu_moe_finalize_routing(
|
|
hidden_states,
|
|
skip1=None,
|
|
skip2=None,
|
|
bias=None,
|
|
scales=topk_weights,
|
|
expanded_src_to_dst_row=expanded_row_idx,
|
|
export_for_source_row=topk_ids,
|
|
)
|
|
|
|
if len(original_shape) == 3:
|
|
final_hidden_states = final_hidden_states.view(original_shape)
|
|
return final_hidden_states
|
|
|
|
|
|
class AscendW8A8DynamicLinearMethod:
|
|
"""Linear method for Ascend W8A8_DYNAMIC.
|
|
"""
|
|
|
|
def __init__(self):
|
|
self.transpose_weight = True
|
|
|
|
@staticmethod
|
|
def get_weight(input_size: int, output_size: int,
|
|
params_dtype: torch.dtype) -> Dict[str, Any]:
|
|
params_dict = {
|
|
"weight": torch.empty(output_size, input_size, dtype=torch.int8)
|
|
}
|
|
return params_dict
|
|
|
|
@staticmethod
|
|
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
|
|
return {}
|
|
|
|
@staticmethod
|
|
def get_perchannel_param(
|
|
output_size: int,
|
|
params_dtype: torch.dtype,
|
|
) -> Dict[str, Any]:
|
|
params_dict = {}
|
|
params_dict["weight_scale"] = torch.empty(output_size,
|
|
1,
|
|
dtype=params_dtype)
|
|
params_dict["weight_offset"] = torch.empty(output_size,
|
|
1,
|
|
dtype=params_dtype)
|
|
return params_dict
|
|
|
|
@staticmethod
|
|
def apply(
|
|
layer: torch.nn.Module,
|
|
x: torch.Tensor,
|
|
bias: Optional[torch.Tensor] = None,
|
|
tp_rank: Optional[int] = 0,
|
|
) -> torch.Tensor:
|
|
original_dtype = x.dtype
|
|
# use ATB quantize
|
|
quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
|
|
return torch_npu.npu_quant_matmul(
|
|
quant_out,
|
|
layer.weight,
|
|
layer.weight_scale,
|
|
pertoken_scale=dynamic_scale,
|
|
bias=bias,
|
|
output_dtype=original_dtype,
|
|
)
|
|
|
|
def process_weights_after_loading(self, layer):
|
|
if self.transpose_weight:
|
|
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
|
|
# cast quantized weight tensors in NZ format (29) for higher inference speed
|
|
layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
|
|
layer.weight_scale.data = layer.weight_scale.data.flatten()
|
|
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
|
|
layer.weight_offset.data = layer.weight_offset.data.flatten()
|
|
|
|
|
|
class AscendW8A8DynamicFusedMoEMethod:
|
|
"""FusedMoe method for Ascend W8A8_DYNAMIC.
|
|
"""
|
|
|
|
def __init__(self):
|
|
self.transpose_weight = True
|
|
|
|
self.ep_group = get_ep_group()
|
|
|
|
ascend_config = get_ascend_config()
|
|
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
|
|
|
try:
|
|
device_group = self.ep_group.device_group
|
|
# TODO: Try local_rank = ep_group.rank_in_group
|
|
local_rank = torch.distributed.get_rank(group=device_group)
|
|
backend = device_group._get_backend(torch.device("npu"))
|
|
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
|
|
local_rank)
|
|
except AttributeError:
|
|
self.moe_all_to_all_group_name = ""
|
|
|
|
@staticmethod
|
|
def get_weight(num_experts: int, intermediate_size_per_partition: int,
|
|
hidden_sizes: int,
|
|
params_dtype: torch.dtype) -> Dict[str, Any]:
|
|
param_dict = {}
|
|
param_dict["w13_weight"] = torch.empty(num_experts,
|
|
2 *
|
|
intermediate_size_per_partition,
|
|
hidden_sizes,
|
|
dtype=torch.int8)
|
|
param_dict["w2_weight"] = torch.empty(num_experts,
|
|
hidden_sizes,
|
|
intermediate_size_per_partition,
|
|
dtype=torch.int8)
|
|
return param_dict
|
|
|
|
@staticmethod
|
|
def get_dynamic_quant_param(num_experts: int,
|
|
intermediate_size_per_partition: int,
|
|
hidden_sizes: int,
|
|
params_dtype: torch.dtype) -> Dict[str, Any]:
|
|
param_dict = {}
|
|
param_dict["w13_weight_scale"] = torch.empty(
|
|
num_experts,
|
|
2 * intermediate_size_per_partition,
|
|
1,
|
|
dtype=params_dtype)
|
|
param_dict["w13_weight_offset"] = torch.empty(
|
|
num_experts,
|
|
2 * intermediate_size_per_partition,
|
|
1,
|
|
dtype=params_dtype)
|
|
param_dict["w2_weight_scale"] = torch.empty(num_experts,
|
|
hidden_sizes,
|
|
1,
|
|
dtype=params_dtype)
|
|
param_dict["w2_weight_offset"] = torch.empty(num_experts,
|
|
hidden_sizes,
|
|
1,
|
|
dtype=params_dtype)
|
|
return param_dict
|
|
|
|
def apply(
|
|
self,
|
|
layer: torch.nn.Module,
|
|
x: torch.Tensor,
|
|
router_logits: torch.Tensor,
|
|
top_k: int,
|
|
renormalize: bool,
|
|
use_grouped_topk: bool = False,
|
|
global_num_experts: int = -1,
|
|
expert_map: Optional[torch.Tensor] = None,
|
|
topk_group: Optional[int] = None,
|
|
num_expert_group: Optional[int] = None,
|
|
custom_routing_function: Optional[Callable] = None,
|
|
scoring_func: str = "softmax",
|
|
e_score_correction_bias: Optional[torch.Tensor] = None,
|
|
is_prefill: bool = True,
|
|
enable_force_load_balance: bool = True,
|
|
log2phy: torch.Tensor = None,
|
|
global_redundant_expert_num: int = 0,
|
|
**kwargs,
|
|
) -> torch.Tensor:
|
|
assert router_logits.shape[
|
|
1] == global_num_experts, "Number of global experts mismatch"
|
|
|
|
# NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
|
|
if global_num_experts == 256:
|
|
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
|
|
router_logits,
|
|
k=top_k, # topk当前写8
|
|
bias=e_score_correction_bias,
|
|
k_group=topk_group, # fix: 4
|
|
group_count=num_expert_group, # fix 8
|
|
group_select_mode=1, # 0: group中的最大; 1: topk2.sum(fix)
|
|
renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
|
|
norm_type=1, # 0: softmax; 1: sigmoid(fix)
|
|
# out_flag=False, # todo new api; 第三个输出是否输出
|
|
# y2_flag=False, # old api; 第三个输出是否输出
|
|
routed_scaling_factor=1,
|
|
eps=float(1e-20))
|
|
else:
|
|
topk_weights, topk_ids = select_experts(
|
|
hidden_states=x,
|
|
router_logits=router_logits,
|
|
top_k=top_k,
|
|
use_grouped_topk=use_grouped_topk,
|
|
renormalize=renormalize,
|
|
topk_group=topk_group,
|
|
num_expert_group=num_expert_group,
|
|
custom_routing_function=custom_routing_function,
|
|
scoring_func=scoring_func,
|
|
e_score_correction_bias=e_score_correction_bias,
|
|
)
|
|
|
|
# this is a naive implementation for experts load balance so as
|
|
# to avoid accumulating too much tokens on a single rank.
|
|
# currently it is only activated when doing profile runs.
|
|
if enable_force_load_balance:
|
|
topk_ids = torch.randint_like(topk_ids, 0, global_num_experts)
|
|
|
|
topk_weights = topk_weights.to(x.dtype)
|
|
|
|
if VLLM_ENABLE_MC2 and not is_prefill:
|
|
return fused_experts_with_mc2(
|
|
hidden_states=x,
|
|
w1=layer.w13_weight,
|
|
w2=layer.w2_weight,
|
|
w1_scale=layer.w13_weight_scale,
|
|
w2_scale=layer.w2_weight_scale,
|
|
topk_weights=topk_weights,
|
|
topk_ids=topk_ids,
|
|
top_k=top_k,
|
|
expert_map=expert_map,
|
|
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
|
log2phy=log2phy,
|
|
global_redundant_expert_num=global_redundant_expert_num,
|
|
**kwargs)
|
|
elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
|
|
return fused_experts(hidden_states=x,
|
|
w1=layer.w13_weight,
|
|
w1_scale=layer.w13_weight_scale,
|
|
w2=layer.w2_weight,
|
|
w2_scale=layer.w2_weight_scale,
|
|
topk_weights=topk_weights,
|
|
topk_ids=topk_ids,
|
|
top_k=top_k,
|
|
expert_map=expert_map)
|
|
else:
|
|
# The current implementation of deepseek moe splits hidden_states
|
|
# according to tp_size before they are feed into fused_moe module.
|
|
# Therefore, all2all is needed no matter how dp/tp is set so as to
|
|
# dispatch/combine tokens.
|
|
return fused_experts_with_all2all(
|
|
hidden_states=x,
|
|
w1=layer.w13_weight,
|
|
w1_scale=layer.w13_weight_scale,
|
|
w2=layer.w2_weight,
|
|
w2_scale=layer.w2_weight_scale,
|
|
topk_weights=topk_weights,
|
|
topk_ids=topk_ids,
|
|
top_k=top_k,
|
|
expert_map=expert_map,
|
|
ep_group=self.ep_group,
|
|
log2phy=log2phy,
|
|
global_redundant_expert_num=global_redundant_expert_num,
|
|
)
|
|
|
|
def process_weights_after_loading(self, layer):
|
|
if self.transpose_weight:
|
|
layer.w13_weight.data = layer.w13_weight.data.transpose(
|
|
1, 2).contiguous()
|
|
layer.w2_weight.data = layer.w2_weight.data.transpose(
|
|
1, 2).contiguous()
|
|
layer.w13_weight_scale.data = layer.w13_weight_scale.data.view(
|
|
layer.w13_weight_scale.data.shape[0], -1)
|
|
layer.w13_weight_offset.data = layer.w13_weight_offset.data.view(
|
|
layer.w13_weight_offset.data.shape[0], -1)
|
|
layer.w2_weight_scale.data = layer.w2_weight_scale.data.view(
|
|
layer.w2_weight_scale.data.shape[0], -1)
|
|
layer.w2_weight_offset.data = layer.w2_weight_offset.data.view(
|
|
layer.w2_weight_offset.data.shape[0], -1)
|