port deepseekv2 and mtp to main branch (#429)

### What this PR does / why we need it?
This PR ports all of the DeepSeek graph-mode code and the MTP code from
the v0.7.3 branch to the main branch.
---------

Signed-off-by: SidaoY <1024863041@qq.com>
Signed-off-by: linfeng-yuan <1102311262@qq.com>
Signed-off-by: Yizhou Liu <liuyizhou5@h-partners.com>
Signed-off-by: mengwei805 <mengwei25@huawei.com>
Signed-off-by: libaokui <libaokui@huawei.com>
Signed-off-by: q00832892 <qiaoyang19@huawei.com>
Signed-off-by: ganyi <pleaplusone.gy@gmail.com>
Co-authored-by: SidaoY <1024863041@qq.com>
Co-authored-by: linfeng-yuan <1102311262@qq.com>
Co-authored-by: Yizhou Liu <liuyizhou5@h-partners.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: libaokui <libaokui@huawei.com>
This commit is contained in:
Pleaplusone
2025-04-19 17:38:18 +08:00
committed by GitHub
parent 086423dc35
commit 1a1f9a6d89
33 changed files with 3361 additions and 315 deletions

View File

@@ -15,12 +15,131 @@
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/kernels/test_moe.py
import os
from typing import Callable, Optional
import torch
import torch.distributed as dist
import torch_npu
from vllm.model_executor.layers.fused_moe.layer import \
UnquantizedFusedMoEMethod
from vllm.config import get_current_vllm_config
from vllm.distributed import tensor_model_parallel_all_reduce
from vllm.distributed.parallel_state import get_dp_group
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod, determine_expert_map)
from vllm.model_executor.layers.quantization.base_config import \
QuantizeMethodBase
from vllm_ascend.distributed.parallel_state import get_ep_group, get_etp_group
def fused_experts_with_mc2(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
    expert_map: Optional[torch.Tensor] = None,
    moe_all_to_all_group_name: Optional[str] = None,
) -> torch.Tensor:
    """Run the routed experts via the NPU dispatch / grouped-matmul / combine ops.

    The tokens are handed to ``torch_npu.npu_moe_distribute_dispatch``, pushed
    through two ``torch_npu.npu_grouped_matmul`` calls with
    ``torch_npu.npu_swiglu`` in between, and finally merged back with
    ``torch_npu.npu_moe_distribute_combine``.

    Args:
        hidden_states: Input passed to the dispatch op as ``x``.
        w1: First expert weight; dims 1 and 2 are transposed before use.
        w2: Second expert weight; dims 1 and 2 are transposed before use.
        topk_weights: Routing weights, cast to float32 and passed to the
            combine op as ``expert_scales``.
        topk_ids: Routing indices, passed to both ops as ``expert_ids``.
        top_k: Not read by this function.
        expert_map: Required. Only its length is used, as ``moe_expert_num``
            (presumably the global expert count -- confirm against the
            caller).
        moe_all_to_all_group_name: Communicator name passed as both
            ``group_ep`` and ``group_tp``.

    Returns:
        The tensor returned by ``torch_npu.npu_moe_distribute_combine``.

    Raises:
        ValueError: If ``expert_map`` is ``None``.
    """
    if expert_map is None:
        # Previously this surfaced as an opaque TypeError from len(None).
        raise ValueError(
            "fused_experts_with_mc2 requires expert_map; its length is used "
            "as moe_expert_num")
    moe_expert_num = len(expert_map)

    rank = torch.distributed.get_rank()
    ep_group = get_ep_group().device_group
    local_rank = torch.distributed.get_rank(group=ep_group)
    all_to_all_group_size = torch.distributed.get_world_size(ep_group)
    world_size = torch.distributed.get_world_size()
    tp_size = world_size // all_to_all_group_size
    tp_rank = rank % tp_size

    # Arguments that are identical for the dispatch and the combine op.
    routing_kwargs = {
        "expert_ids": topk_ids,
        "expert_shard_type": 0,
        "shared_expert_rank_num": 0,
        "moe_expert_num": moe_expert_num,
        "global_bs": 0,
    }
    comm_kwargs = {
        "group_ep": moe_all_to_all_group_name,
        "ep_world_size": all_to_all_group_size,
        "ep_rank_id": local_rank,
        # group_tp is given the same communicator name as group_ep; the
        # alternative that used to be commented out here was
        # self.moe_rs_group_name.
        "group_tp": moe_all_to_all_group_name,
        "tp_world_size": tp_size,
        "tp_rank_id": tp_rank,
    }

    output = torch_npu.npu_moe_distribute_dispatch(
        x=hidden_states,
        scales=None,
        quant_mode=0,
        **routing_kwargs,
        **comm_kwargs,
    )
    # comm_stream.wait_stream(torch.npu.current_stream())
    # dynamic_scale (output[1]) is unpacked for readability but not used.
    expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = (
        output[0:5])
    tp_recv_counts = output[5]

    # The per-expert token counts are turned into a running total; the
    # cumsum already yields int64, so no further cast is needed.
    group_list = torch.cumsum(expert_token_nums, dim=0, dtype=torch.int64)

    gate_up_out_list = torch_npu.npu_grouped_matmul(
        x=[expand_x],
        weight=[w1.transpose(1, 2)],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )
    # TODO: Remove this in the future.
    gate_up_out = torch.cat(gate_up_out_list, dim=0)
    gate_up_out = torch_npu.npu_swiglu(gate_up_out)

    down_out_list = torch_npu.npu_grouped_matmul(
        x=[gate_up_out],
        weight=[w2.transpose(1, 2)],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=group_list,
    )
    down_out = torch.cat(down_out_list, dim=0)

    # moeCombine: the receive counts reported by the dispatch op are fed back
    # as the send counts of the combine op.
    return torch_npu.npu_moe_distribute_combine(
        expand_x=down_out,
        expand_idx=expand_idx,
        expert_scales=topk_weights.to(torch.float32),
        ep_send_counts=ep_recv_counts,
        tp_send_counts=tp_recv_counts,
        **routing_kwargs,
        **comm_kwargs,
    )
def fused_experts(
@@ -47,22 +166,27 @@ def fused_experts(
Returns:
hidden_states: Hidden states after routing.
"""
"""
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
"""
# if torch.distributed.get_rank() == 0:
# print(w1.shape)
# print(hidden_states.shape)
original_shape = hidden_states.shape
assert len(original_shape) == 2
# assert len(original_shape) == 2
num_tokens = hidden_states.shape[:-1].numel()
num_experts = w1.shape[0]
dtype = hidden_states.dtype
device = hidden_states.device
assert dtype in [torch.float32, torch.float16, torch.bfloat16
], "Only float32, float16, and bfloat16 are supported"
# assert dtype in [torch.float32, torch.float16, torch.bfloat16
# ], "Only float32, float16, and bfloat16 are supported"
if expert_map is not None:
# Generate token indices and flatten
@@ -152,11 +276,18 @@ def fused_experts(
final_hidden_states = torch.zeros(*original_shape,
device=hidden_states.device,
dtype=dtype)
final_hidden_states.index_add_(0, sorted_token_indices,
weighted_down_out)
# TODO: This should not happen! Look into it!
# fill nan with 0.0
final_hidden_states[torch.isnan(final_hidden_states)] = 0.0
# TODO: npu_grouped_matmul outputs random values at [num_valid_tokens:, ...].
# These can contain NaNs, and index_add_ would mix them into the result, which harms accuracy.
# Remove this mask and filter once that is fixed.
num_valid_tokens = mask.sum()
valid_token_mask = torch.arange(
0, sorted_token_indices.shape[0],
device=device).unsqueeze(1) < num_valid_tokens
valid_output = torch.where(
valid_token_mask, weighted_down_out,
torch.zeros_like(weighted_down_out)).to(dtype)
final_hidden_states.index_add_(0, sorted_token_indices, valid_output)
else:
# TODO: Reorder device memory 2 times here, replace the current
# implementation here when suitable operators become available.
@@ -199,16 +330,17 @@ def native_grouped_topk(
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
is_prefill: Optional[bool] = True
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select top-k experts based on router logits.
@@ -232,8 +364,23 @@ def select_experts(
Raises:
ValueError: If an unsupported scoring function is provided.
"""
assert hidden_states.shape[0] == router_logits.shape[0], (
"Number of tokens mismatch")
# assert hidden_states.shape[0] == router_logits.shape[0], (
# "Number of tokens mismatch")
# if os.environ.get("VLLM_ENABLE_GRAPH_MODE") == "1" and not is_prefill:
# topk_weight, topk_idx, _ = torch.ops.npu_inference.npu_moe_gating_top_k(
# router_logits,
# k=top_k, # top-k is currently set to 8
# bias=e_score_correction_bias,
# k_group=topk_group, # fix: 4
# group_count=num_expert_group, # fix 8
# group_select_mode=1, # 0: max within the group; 1: sum of the top-2 (fix)
# renorm=0, # 0: softmax->topk(fix); 1: topk->softmax
# norm_type=1, # 0: softmax; 1: sigmoid(fix)
# # out_flag=False, # todo new api; whether to emit the third output
# # y2_flag=False, # old api; whether to emit the third output
# routed_scaling_factor=1,
# eps=float(1e-20))
# return topk_weight, topk_idx
if custom_routing_function is not None:
raise NotImplementedError(
@@ -261,14 +408,16 @@ def select_experts(
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
topk_weights = native_grouped_topk(topk_weights, num_expert_group,
topk_group)
# TODO bfloat16 is not supported in torch.topk with ge graph.
if e_score_correction_bias is not None:
topk_ids = torch.topk(topk_weights, k=top_k, dim=-1,
topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_weights.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(topk_weights,
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)
@@ -285,46 +434,245 @@ def select_experts(
return topk_weights, topk_ids
def forward_oot(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
**kwargs,
):
assert router_logits.shape[
1] == global_num_experts, "Number of global experts mismatch"
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
top_k=top_k,
use_grouped_topk=use_grouped_topk,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
)
def __init__(self):
super().__init__()
vllm_config = get_current_vllm_config()
return fused_experts(hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
expert_map=expert_map)
ep_group = get_ep_group()
self.ep_size = ep_group.world_size
self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
self.local_batch_size = self.global_batch_size // self.ep_size
try:
device_group = ep_group.device_group
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=device_group)
backend = device_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(
local_rank)
except AttributeError:
self.moe_all_to_all_group_name = None
def process_weights_after_loading(self, layer):
    """Re-wrap the expert weights of ``layer`` after checkpoint loading.

    The parent hook is looked up starting *after* ``UnquantizedFusedMoEMethod``
    in the MRO, so that class's own implementation is not run. Afterwards
    ``w13_weight`` and ``w2_weight`` are each passed through
    ``self._maybe_pad_weight`` and stored back as frozen parameters.
    """
    parent_hook = super(UnquantizedFusedMoEMethod,
                        self).process_weights_after_loading
    parent_hook(layer)

    # Same order as before: w13 first, then w2.
    for weight_name in ("w13_weight", "w2_weight"):
        padded = self._maybe_pad_weight(getattr(layer, weight_name).data)
        setattr(layer, weight_name,
                torch.nn.Parameter(padded, requires_grad=False))
def apply(
    self,
    layer: torch.nn.Module,
    x: torch.Tensor,
    use_grouped_topk: bool,
    top_k: int,
    router_logits: torch.Tensor,
    renormalize: bool,
    topk_group: Optional[int] = None,
    num_expert_group: Optional[int] = None,
    global_num_experts: int = -1,
    expert_map: Optional[torch.Tensor] = None,
    custom_routing_function: Optional[Callable] = None,
    scoring_func: str = "softmax",
    e_score_correction_bias: Optional[torch.Tensor] = None,
    is_prefill=False,
    **kwargs,
):
    """Route ``x`` with ``select_experts`` and run the chosen experts.

    When the environment variable ``VLLM_ENABLE_MC2`` equals ``"1"`` and
    ``is_prefill`` is falsy, the experts run through
    ``fused_experts_with_mc2``; otherwise through ``fused_experts``.
    ``global_num_experts`` and ``**kwargs`` are accepted but not read.

    Returns:
        Whatever the selected fused-experts function returns.
    """
    # assert router_logits.shape[
    # 1] == global_num_experts, "Number of global experts mismatch"
    # set prefill as false always, should fix this
    routing_weights, routing_ids = select_experts(
        hidden_states=x,
        router_logits=router_logits,
        top_k=top_k,
        use_grouped_topk=use_grouped_topk,
        renormalize=renormalize,
        topk_group=topk_group,
        num_expert_group=num_expert_group,
        custom_routing_function=custom_routing_function,
        scoring_func=scoring_func,
        e_score_correction_bias=e_score_correction_bias,
        is_prefill=is_prefill)

    # Arguments shared by both expert implementations.
    expert_args = dict(hidden_states=x,
                       w1=layer.w13_weight,
                       w2=layer.w2_weight,
                       topk_weights=routing_weights,
                       topk_ids=routing_ids,
                       top_k=top_k,
                       expert_map=expert_map)

    mc2_requested = os.environ.get("VLLM_ENABLE_MC2") == "1"
    if mc2_requested and not is_prefill:
        return fused_experts_with_mc2(
            moe_all_to_all_group_name=self.moe_all_to_all_group_name,
            **expert_args)
    return fused_experts(**expert_args)
UnquantizedFusedMoEMethod.forward_oot = forward_oot
class AscendFusedMoE(FusedMoE):
    """``FusedMoE`` variant that sizes itself from the EP / ETP / DP groups.

    The layer reads its parallel sizes from ``get_ep_group()``,
    ``get_etp_group()`` and ``get_dp_group()``, creates its weights through
    the selected quantization method, and in ``forward`` wraps the expert
    computation with the data-parallel gather / reduce-scatter steps.
    """

    def __init__(self,
                 num_experts,
                 top_k,
                 hidden_size,
                 intermediate_size,
                 params_dtype=None,
                 reduce_results=False,
                 renormalize=True,
                 use_grouped_topk=False,
                 num_expert_group=None,
                 topk_group=None,
                 quant_config=None,
                 tp_size=None,
                 ep_size=None,
                 dp_size=None,
                 prefix="",
                 custom_routing_function=None,
                 scoring_func="softmax",
                 e_score_correction_bias=None,
                 activation="silu"):
        """Set up routing attributes, the expert map and the expert weights.

        ``tp_size`` and ``ep_size`` are accepted but not read: both sizes are
        taken from the ETP and EP process groups. ``dp_size`` falls back to
        the DP group's world size when it is ``None``.

        Raises:
            ValueError: If ``scoring_func`` is not ``"softmax"`` while
                ``use_grouped_topk`` is false.
        """
        # Starts the initializer chain after FusedMoE in the MRO, so
        # FusedMoE.__init__ itself is not run; the attributes are set below.
        super(FusedMoE, self).__init__()

        if params_dtype is None:
            params_dtype = torch.get_default_dtype()

        self.ep_size = get_ep_group().world_size
        self.tp_size = get_etp_group().world_size
        self.dp_size = (dp_size
                        if dp_size is not None else get_dp_group().world_size)
        self.dp_rank = (0
                        if self.dp_size == 1 else get_dp_group().rank_in_group)

        self.top_k = top_k
        self.num_experts = num_experts
        self.global_num_experts = num_experts
        assert intermediate_size % self.tp_size == 0
        self.intermediate_size_per_partition = intermediate_size // self.tp_size
        self.reduce_results = reduce_results
        self.renormalize = renormalize
        self.use_grouped_topk = use_grouped_topk
        if self.use_grouped_topk:
            assert num_expert_group is not None and topk_group is not None
        self.num_expert_group = num_expert_group
        self.topk_group = topk_group
        self.custom_routing_function = custom_routing_function
        self.scoring_func = scoring_func
        self.e_score_correction_bias = e_score_correction_bias
        self.expert_map = None
        self.activation = activation

        if self.ep_size > 1:
            # Create a tensor of size num_experts filled with -1
            self.local_num_experts, self.expert_map = determine_expert_map(
                self.ep_size,
                get_ep_group().rank_in_group, self.global_num_experts)
            self.tp_rank = get_etp_group().rank_in_group
            self.ep_rank = get_ep_group().rank_in_group
        else:
            # Adjust TP size for DP attention.
            # This path has not been tested yet and may be removed in the
            # future.
            self.tp_rank = self.tp_size * self.dp_rank
            self.ep_rank = 0
            self.tp_size = self.tp_size * self.dp_size
            self.ep_size = 1
            self.local_num_experts = self.global_num_experts
            self.expert_map = None

        if self.scoring_func != "softmax" and not self.use_grouped_topk:
            raise ValueError("Only softmax scoring function is supported for "
                             "non-grouped topk.")

        if quant_config is None:
            self.quant_method: Optional[QuantizeMethodBase] = (
                AscendUnquantizedFusedMoEMethod())
        else:
            self.quant_method = quant_config.get_quant_method(self, prefix)
        assert self.quant_method is not None

        # Entries of expert_map that differ from -1 are counted as the
        # experts held locally.
        local_num_experts = torch.sum(self.expert_map != -1) \
            if self.expert_map is not None else num_experts

        moe_quant_params = {
            "num_experts": local_num_experts,
            "hidden_size": hidden_size,
            "intermediate_size_per_partition":
            self.intermediate_size_per_partition,
            "params_dtype": params_dtype,
            "weight_loader": self.weight_loader,
        }
        # need full intermediate size pre-sharding for WNA16 act order
        if (self.quant_method.__class__.__name__
                in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
            moe_quant_params["intermediate_size_full"] = intermediate_size

        self.quant_method.create_weights(layer=self, **moe_quant_params)

    @staticmethod
    def _env_flag_enabled(name: str) -> bool:
        """Return True when environment variable ``name`` parses to int 1.

        An unset variable counts as disabled. Previously an unset variable
        reached ``int(None)`` and raised ``TypeError``.
        """
        return int(os.environ.get(name, "0")) == 1

    def forward(self,
                hidden_states: torch.Tensor,
                router_logits: torch.Tensor,
                is_prefill: bool,
                top_k=None):
        """Apply the MoE layer to ``hidden_states``.

        Args:
            hidden_states: Input activations.
            router_logits: Router scores for the same tokens.
            is_prefill: Forwarded to the quant method; together with
                ``VLLM_ENABLE_MC2`` it decides whether the DP gather and
                reduce-scatter steps are skipped.
            top_k: Optional override; ``self.top_k`` is used when falsy.

        Returns:
            The expert output, after the DP reduce-scatter and the
            tensor-parallel all-reduce when those steps apply.
        """
        assert self.quant_method is not None

        real_top_k = top_k if top_k else self.top_k

        # The environment is consulted only when DP is active, as before.
        mc2_decode = (self.dp_size > 1
                      and self._env_flag_enabled("VLLM_ENABLE_MC2")
                      and not is_prefill)

        if self.dp_size > 1 and not mc2_decode:
            if self._env_flag_enabled("USING_LCCL_COM"):
                hidden_states = get_dp_group().all_gather(
                    hidden_states, 0, False)
                router_logits = get_dp_group().all_gather(
                    router_logits, 0, False)
            else:
                hidden_states = get_dp_group().all_gather(hidden_states, 0)
                router_logits = get_dp_group().all_gather(router_logits, 0)

        # Matrix multiply.
        final_hidden_states = self.quant_method.apply(
            layer=self,
            x=hidden_states,
            router_logits=router_logits,
            top_k=real_top_k,
            renormalize=self.renormalize,
            use_grouped_topk=self.use_grouped_topk,
            global_num_experts=self.num_experts,
            expert_map=self.expert_map,
            topk_group=self.topk_group,
            num_expert_group=self.num_expert_group,
            custom_routing_function=self.custom_routing_function,
            scoring_func=self.scoring_func,
            e_score_correction_bias=self.e_score_correction_bias,
            is_prefill=is_prefill)

        if self.dp_size > 1 and not mc2_decode:
            final_hidden_states = dist._functional_collectives.reduce_scatter_tensor(
                final_hidden_states,
                "sum",
                scatter_dim=0,
                group=get_dp_group().device_group)

        # if self.reduce_results and self.tp_size > 1:
        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
            final_hidden_states = tensor_model_parallel_all_reduce(
                final_hidden_states)

        return final_hidden_states