### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/fused_moe/comm_utils.py` |
| `vllm_ascend/ops/fused_moe/experts_selector.py` |
| `vllm_ascend/ops/fused_moe/fused_moe.py` |
| `vllm_ascend/ops/fused_moe/moe_comm_method.py` |
| `vllm_ascend/ops/fused_moe/moe_mlp.py` |
| `vllm_ascend/ops/fused_moe/prepare_finalize.py` |
| `vllm_ascend/ops/fused_moe/token_dispatcher.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main:
d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
This commit is contained in:
@@ -14,26 +14,28 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.utils import get_weight_prefetch_method
|
||||
|
||||
|
||||
def select_experts(hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
global_num_experts: int = -1):
|
||||
def select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
indices_type: torch.dtype | None = None,
|
||||
global_num_experts: int = -1,
|
||||
):
|
||||
"""
|
||||
Fused experts with select experts.
|
||||
|
||||
@@ -58,8 +60,7 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
# prefetch w1_w3_proj.weight preprocess
|
||||
weight_prefetch_method = get_weight_prefetch_method()
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
|
||||
hidden_states, "gate_up")
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
|
||||
is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
|
||||
hidden_states=hidden_states,
|
||||
top_k=top_k,
|
||||
@@ -67,7 +68,8 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
scoring_func=scoring_func,
|
||||
custom_routing_function=custom_routing_function)
|
||||
custom_routing_function=custom_routing_function,
|
||||
)
|
||||
|
||||
if is_support_npu_moe_gating_top_k:
|
||||
topk_weights, topk_ids = _select_experts_with_fusion_ops(
|
||||
@@ -81,7 +83,8 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
num_expert_group=num_expert_group,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
global_num_experts=global_num_experts)
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = _native_select_experts(
|
||||
hidden_states=hidden_states,
|
||||
@@ -100,14 +103,15 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
|
||||
|
||||
def check_npu_moe_gating_top_k(
|
||||
hidden_states: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
scoring_func: str = "softmax",
|
||||
custom_routing_function: Optional[Callable] = None):
|
||||
if scoring_func == "sigmoid" and not renormalize: #sigmoid + renorm=0 is not supported in current branch
|
||||
hidden_states: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
custom_routing_function: Callable | None = None,
|
||||
):
|
||||
if scoring_func == "sigmoid" and not renormalize: # sigmoid + renorm=0 is not supported in current branch
|
||||
return False
|
||||
if custom_routing_function is not None:
|
||||
return False
|
||||
@@ -115,39 +119,39 @@ def check_npu_moe_gating_top_k(
|
||||
return False
|
||||
topk_group = topk_group if topk_group is not None else 1
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 1
|
||||
if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
|
||||
== 0 and hidden_states.shape[-1] // num_expert_group > 2):
|
||||
if not (
|
||||
num_expert_group > 0
|
||||
and hidden_states.shape[-1] % num_expert_group == 0
|
||||
and hidden_states.shape[-1] // num_expert_group > 2
|
||||
):
|
||||
return False
|
||||
if topk_group < 1 or topk_group > num_expert_group:
|
||||
return False
|
||||
if top_k < 1 or \
|
||||
top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
|
||||
if top_k < 1 or top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
|
||||
return False
|
||||
if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
|
||||
if topk_group * hidden_states.shape[-1] / num_expert_group < top_k: # noqa: SIM103
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _native_grouped_topk(
|
||||
topk_weights: torch.Tensor,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
):
|
||||
topk_group = 0 if topk_group is None else topk_group
|
||||
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
||||
|
||||
num_token = topk_weights.shape[0]
|
||||
grouped_weights = topk_weights.view(num_token, num_expert_group,
|
||||
-1).max(dim=-1).values
|
||||
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
|
||||
k=topk_group,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
grouped_weights = topk_weights.view(num_token, num_expert_group, -1).max(dim=-1).values
|
||||
topk_group_indices = torch.topk(grouped_weights.to(torch.float32), k=topk_group, dim=-1, sorted=False)[1]
|
||||
topk_group_mask = torch.zeros_like(grouped_weights)
|
||||
topk_group_mask.scatter_(1, topk_group_indices, 1)
|
||||
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
|
||||
num_token, num_expert_group,
|
||||
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
|
||||
topk_weight_mask = (
|
||||
topk_group_mask.unsqueeze(-1)
|
||||
.expand(num_token, num_expert_group, topk_weights.shape[-1] // num_expert_group)
|
||||
.reshape(num_token, -1)
|
||||
)
|
||||
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
|
||||
|
||||
return topk_weights
|
||||
@@ -163,9 +167,13 @@ def _renormalize_topk_weights(
|
||||
|
||||
|
||||
def _select_expert_use_group_topk(
|
||||
topk_weights: torch.Tensor, topk_group: Optional[int],
|
||||
renormalize: bool, top_k: int, num_expert_group: Optional[int],
|
||||
e_score_correction_bias: Optional[torch.Tensor]):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_group: int | None,
|
||||
renormalize: bool,
|
||||
top_k: int,
|
||||
num_expert_group: int | None,
|
||||
e_score_correction_bias: torch.Tensor | None,
|
||||
):
|
||||
assert topk_group is not None
|
||||
assert num_expert_group is not None
|
||||
|
||||
@@ -177,47 +185,38 @@ def _select_expert_use_group_topk(
|
||||
|
||||
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
|
||||
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
||||
topk_weights = _native_grouped_topk(topk_weights, num_expert_group,
|
||||
topk_group)
|
||||
topk_weights = _native_grouped_topk(topk_weights, num_expert_group, topk_group)
|
||||
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
||||
if e_score_correction_bias is not None:
|
||||
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)[1]
|
||||
# Use original unbiased scores for the routing weights
|
||||
topk_weights = original_weights.gather(1, topk_ids)
|
||||
else:
|
||||
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)
|
||||
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def _select_experts_with_fusion_ops(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
e_score_correction_bias: Optional[torch.Tensor],
|
||||
topk_group: Optional[int],
|
||||
num_expert_group: Optional[int],
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
global_num_experts: int = -1):
|
||||
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
e_score_correction_bias: torch.Tensor | None,
|
||||
topk_group: int | None,
|
||||
num_expert_group: int | None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
global_num_experts: int = -1,
|
||||
):
|
||||
topk_group = topk_group if topk_group is not None else 1
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 1
|
||||
renorm = int(renormalize)
|
||||
norm_type = 0 if scoring_func == "softmax" else 1
|
||||
if e_score_correction_bias is not None and \
|
||||
e_score_correction_bias.dtype != router_logits.dtype:
|
||||
e_score_correction_bias = e_score_correction_bias.to(
|
||||
router_logits.dtype)
|
||||
if e_score_correction_bias is not None and e_score_correction_bias.dtype != router_logits.dtype:
|
||||
e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
|
||||
topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k,
|
||||
@@ -228,7 +227,7 @@ def _select_experts_with_fusion_ops(
|
||||
norm_type=norm_type, # 0: softmax; 1: sigmoid
|
||||
out_flag=False,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
eps=float(1e-20),
|
||||
eps=1e-20,
|
||||
bias_opt=e_score_correction_bias,
|
||||
)
|
||||
|
||||
@@ -241,12 +240,12 @@ def _native_select_experts(
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
global_num_experts: Optional[torch.Tensor] = None
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
global_num_experts: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Select top-k experts based on router logits.
|
||||
@@ -285,7 +284,8 @@ def _native_select_experts(
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if custom_routing_function is not None:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
@@ -293,7 +293,8 @@ def _native_select_experts(
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
global_num_experts=global_num_experts)
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
# Required by npu_moe_init_routing
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
return topk_weights, topk_ids
|
||||
@@ -318,8 +319,7 @@ def zero_experts_compute(
|
||||
if zero_expert_type == "identity":
|
||||
zero_expert_mask = expert_indices < num_experts
|
||||
zero_expert_scales = expert_scales.clone()
|
||||
zero_expert_scales = torch.where(zero_expert_mask, 0.0,
|
||||
zero_expert_scales)
|
||||
zero_expert_scales = torch.where(zero_expert_mask, 0.0, zero_expert_scales)
|
||||
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
zero_expert_scales = zero_expert_scales.unsqueeze(2)
|
||||
|
||||
Reference in New Issue
Block a user