[Lint]Style: Convert vllm-ascend/ to ruff format(Batch #11) (#6176)

### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/fused_moe/comm_utils.py` |
| `vllm_ascend/ops/fused_moe/experts_selector.py` |
| `vllm_ascend/ops/fused_moe/fused_moe.py` |
| `vllm_ascend/ops/fused_moe/moe_comm_method.py` |
| `vllm_ascend/ops/fused_moe/moe_mlp.py` |
| `vllm_ascend/ops/fused_moe/prepare_finalize.py` |
| `vllm_ascend/ops/fused_moe/token_dispatcher.py` |

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main:
d68209402d

Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
This commit is contained in:
SILONG ZENG
2026-02-06 15:28:49 +08:00
committed by GitHub
parent 4fb3d5e1b2
commit 65b7f716e6
8 changed files with 694 additions and 784 deletions

View File

@@ -62,9 +62,6 @@ exclude = [
"vllm_ascend/worker/v2/**", "vllm_ascend/worker/v2/**",
"vllm_ascend/worker/npu_input_batch.py", "vllm_ascend/worker/npu_input_batch.py",
"vllm_ascend/ops/rotary_embedding.py", "vllm_ascend/ops/rotary_embedding.py",
# (11)
"vllm_ascend/ops/fused_moe/**",
] ]
[tool.ruff.lint] [tool.ruff.lint]

View File

@@ -23,11 +23,7 @@ import torch_npu
COMM_STREAM = None COMM_STREAM = None
def async_all_to_all(input_, def async_all_to_all(input_, output_split_sizes, input_split_sizes, group, event=None):
output_split_sizes,
input_split_sizes,
group,
event=None):
if output_split_sizes is None: if output_split_sizes is None:
# Equal split (all2all) # Equal split (all2all)
a2a_out = torch.empty_like(input_) a2a_out = torch.empty_like(input_)
@@ -43,8 +39,7 @@ def async_all_to_all(input_,
# multi stream wait event # multi stream wait event
global COMM_STREAM global COMM_STREAM
if COMM_STREAM is None: if COMM_STREAM is None:
COMM_STREAM = torch_npu.npu.Stream( COMM_STREAM = torch_npu.npu.Stream(device=torch.npu.current_device())
device=torch.npu.current_device())
with torch_npu.npu.stream(COMM_STREAM): with torch_npu.npu.stream(COMM_STREAM):
event.wait() event.wait()
handle = dist.all_to_all_single( handle = dist.all_to_all_single(
@@ -53,14 +48,17 @@ def async_all_to_all(input_,
output_split_sizes=output_split_sizes, output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes, input_split_sizes=input_split_sizes,
group=group, group=group,
async_op=True) async_op=True,
)
else: else:
handle = dist.all_to_all_single(a2a_out, handle = dist.all_to_all_single(
input_.contiguous(), a2a_out,
output_split_sizes=output_split_sizes, input_.contiguous(),
input_split_sizes=input_split_sizes, output_split_sizes=output_split_sizes,
group=group, input_split_sizes=input_split_sizes,
async_op=True) group=group,
async_op=True,
)
return input_, a2a_out, handle return input_, a2a_out, handle
@@ -86,19 +84,12 @@ def _gather_along_first_dim(input_, group, output_split_sizes=None):
if output_split_sizes is None: if output_split_sizes is None:
dim_size[0] = dim_size[0] * world_size dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size, output = torch.empty(dim_size, dtype=input_.dtype, device=torch.npu.current_device())
dtype=input_.dtype, torch.distributed.all_gather_into_tensor(output, input_.contiguous(), group=group)
device=torch.npu.current_device())
torch.distributed.all_gather_into_tensor(output,
input_.contiguous(),
group=group)
else: else:
dim_size[0] = sum(output_split_sizes) dim_size[0] = sum(output_split_sizes)
output = torch.empty(dim_size, output = torch.empty(dim_size, dtype=input_.dtype, device=torch.npu.current_device())
dtype=input_.dtype, output_tensor_list = list(torch.split(output, output_split_sizes, dim=0))
device=torch.npu.current_device())
output_tensor_list = list(
torch.split(output, output_split_sizes, dim=0))
torch.distributed.all_gather(output_tensor_list, input_, group=group) torch.distributed.all_gather(output_tensor_list, input_, group=group)
return output return output

View File

@@ -14,26 +14,28 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from typing import Callable, Optional from collections.abc import Callable
import torch import torch
from vllm_ascend.utils import get_weight_prefetch_method from vllm_ascend.utils import get_weight_prefetch_method
def select_experts(hidden_states: torch.Tensor, def select_experts(
router_logits: torch.Tensor, hidden_states: torch.Tensor,
top_k: int, router_logits: torch.Tensor,
use_grouped_topk: bool, top_k: int,
renormalize: bool, use_grouped_topk: bool,
topk_group: Optional[int] = None, renormalize: bool,
num_expert_group: Optional[int] = None, topk_group: int | None = None,
custom_routing_function: Optional[Callable] = None, num_expert_group: int | None = None,
scoring_func: str = "softmax", custom_routing_function: Callable | None = None,
routed_scaling_factor=1.0, scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, routed_scaling_factor=1.0,
indices_type: Optional[torch.dtype] = None, e_score_correction_bias: torch.Tensor | None = None,
global_num_experts: int = -1): indices_type: torch.dtype | None = None,
global_num_experts: int = -1,
):
""" """
Fused experts with select experts. Fused experts with select experts.
@@ -58,8 +60,7 @@ def select_experts(hidden_states: torch.Tensor,
# prefetch w1_w3_proj.weight preprocess # prefetch w1_w3_proj.weight preprocess
weight_prefetch_method = get_weight_prefetch_method() weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method: if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess( weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
hidden_states, "gate_up")
is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k( is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
hidden_states=hidden_states, hidden_states=hidden_states,
top_k=top_k, top_k=top_k,
@@ -67,7 +68,8 @@ def select_experts(hidden_states: torch.Tensor,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
scoring_func=scoring_func, scoring_func=scoring_func,
custom_routing_function=custom_routing_function) custom_routing_function=custom_routing_function,
)
if is_support_npu_moe_gating_top_k: if is_support_npu_moe_gating_top_k:
topk_weights, topk_ids = _select_experts_with_fusion_ops( topk_weights, topk_ids = _select_experts_with_fusion_ops(
@@ -81,7 +83,8 @@ def select_experts(hidden_states: torch.Tensor,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
global_num_experts=global_num_experts) global_num_experts=global_num_experts,
)
else: else:
topk_weights, topk_ids = _native_select_experts( topk_weights, topk_ids = _native_select_experts(
hidden_states=hidden_states, hidden_states=hidden_states,
@@ -100,14 +103,15 @@ def select_experts(hidden_states: torch.Tensor,
def check_npu_moe_gating_top_k( def check_npu_moe_gating_top_k(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
top_k: int, top_k: int,
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: int | None = None,
num_expert_group: Optional[int] = None, num_expert_group: int | None = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
custom_routing_function: Optional[Callable] = None): custom_routing_function: Callable | None = None,
if scoring_func == "sigmoid" and not renormalize: #sigmoid + renorm=0 is not supported in current branch ):
if scoring_func == "sigmoid" and not renormalize: # sigmoid + renorm=0 is not supported in current branch
return False return False
if custom_routing_function is not None: if custom_routing_function is not None:
return False return False
@@ -115,39 +119,39 @@ def check_npu_moe_gating_top_k(
return False return False
topk_group = topk_group if topk_group is not None else 1 topk_group = topk_group if topk_group is not None else 1
num_expert_group = num_expert_group if num_expert_group is not None else 1 num_expert_group = num_expert_group if num_expert_group is not None else 1
if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group if not (
== 0 and hidden_states.shape[-1] // num_expert_group > 2): num_expert_group > 0
and hidden_states.shape[-1] % num_expert_group == 0
and hidden_states.shape[-1] // num_expert_group > 2
):
return False return False
if topk_group < 1 or topk_group > num_expert_group: if topk_group < 1 or topk_group > num_expert_group:
return False return False
if top_k < 1 or \ if top_k < 1 or top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
return False return False
if topk_group * hidden_states.shape[-1] / num_expert_group < top_k: if topk_group * hidden_states.shape[-1] / num_expert_group < top_k: # noqa: SIM103
return False return False
return True return True
def _native_grouped_topk( def _native_grouped_topk(
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
num_expert_group: Optional[int], num_expert_group: int | None,
topk_group: Optional[int], topk_group: int | None,
): ):
topk_group = 0 if topk_group is None else topk_group topk_group = 0 if topk_group is None else topk_group
num_expert_group = 0 if num_expert_group is None else num_expert_group num_expert_group = 0 if num_expert_group is None else num_expert_group
num_token = topk_weights.shape[0] num_token = topk_weights.shape[0]
grouped_weights = topk_weights.view(num_token, num_expert_group, grouped_weights = topk_weights.view(num_token, num_expert_group, -1).max(dim=-1).values
-1).max(dim=-1).values topk_group_indices = torch.topk(grouped_weights.to(torch.float32), k=topk_group, dim=-1, sorted=False)[1]
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
k=topk_group,
dim=-1,
sorted=False)[1]
topk_group_mask = torch.zeros_like(grouped_weights) topk_group_mask = torch.zeros_like(grouped_weights)
topk_group_mask.scatter_(1, topk_group_indices, 1) topk_group_mask.scatter_(1, topk_group_indices, 1)
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand( topk_weight_mask = (
num_token, num_expert_group, topk_group_mask.unsqueeze(-1)
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1)) .expand(num_token, num_expert_group, topk_weights.shape[-1] // num_expert_group)
.reshape(num_token, -1)
)
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0) topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
return topk_weights return topk_weights
@@ -163,9 +167,13 @@ def _renormalize_topk_weights(
def _select_expert_use_group_topk( def _select_expert_use_group_topk(
topk_weights: torch.Tensor, topk_group: Optional[int], topk_weights: torch.Tensor,
renormalize: bool, top_k: int, num_expert_group: Optional[int], topk_group: int | None,
e_score_correction_bias: Optional[torch.Tensor]): renormalize: bool,
top_k: int,
num_expert_group: int | None,
e_score_correction_bias: torch.Tensor | None,
):
assert topk_group is not None assert topk_group is not None
assert num_expert_group is not None assert num_expert_group is not None
@@ -177,47 +185,38 @@ def _select_expert_use_group_topk(
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available # TODO: Change to npu_group_topk when the latest CANN and NNAL is available
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group) # >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
topk_weights = _native_grouped_topk(topk_weights, num_expert_group, topk_weights = _native_grouped_topk(topk_weights, num_expert_group, topk_group)
topk_group)
# TODO bfloat16 is not supported in torch.topk with ge graph. # TODO bfloat16 is not supported in torch.topk with ge graph.
if e_score_correction_bias is not None: if e_score_correction_bias is not None:
topk_ids = torch.topk(topk_weights.to(torch.float32), topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)[1]
k=top_k,
dim=-1,
sorted=False)[1]
# Use original unbiased scores for the routing weights # Use original unbiased scores for the routing weights
topk_weights = original_weights.gather(1, topk_ids) topk_weights = original_weights.gather(1, topk_ids)
else: else:
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)
k=top_k,
dim=-1,
sorted=False)
topk_ids = topk_ids.to(torch.int32) topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize) topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids return topk_weights, topk_ids
def _select_experts_with_fusion_ops( def _select_experts_with_fusion_ops(
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
router_logits: torch.Tensor, router_logits: torch.Tensor,
top_k: int, top_k: int,
use_grouped_topk: bool, use_grouped_topk: bool,
renormalize: bool, renormalize: bool,
e_score_correction_bias: Optional[torch.Tensor], e_score_correction_bias: torch.Tensor | None,
topk_group: Optional[int], topk_group: int | None,
num_expert_group: Optional[int], num_expert_group: int | None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
routed_scaling_factor=1.0, routed_scaling_factor=1.0,
global_num_experts: int = -1): global_num_experts: int = -1,
):
topk_group = topk_group if topk_group is not None else 1 topk_group = topk_group if topk_group is not None else 1
num_expert_group = num_expert_group if num_expert_group is not None else 1 num_expert_group = num_expert_group if num_expert_group is not None else 1
renorm = int(renormalize) renorm = int(renormalize)
norm_type = 0 if scoring_func == "softmax" else 1 norm_type = 0 if scoring_func == "softmax" else 1
if e_score_correction_bias is not None and \ if e_score_correction_bias is not None and e_score_correction_bias.dtype != router_logits.dtype:
e_score_correction_bias.dtype != router_logits.dtype: e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
e_score_correction_bias = e_score_correction_bias.to(
router_logits.dtype)
topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k( topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k(
router_logits, router_logits,
k=top_k, k=top_k,
@@ -228,7 +227,7 @@ def _select_experts_with_fusion_ops(
norm_type=norm_type, # 0: softmax; 1: sigmoid norm_type=norm_type, # 0: softmax; 1: sigmoid
out_flag=False, out_flag=False,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
eps=float(1e-20), eps=1e-20,
bias_opt=e_score_correction_bias, bias_opt=e_score_correction_bias,
) )
@@ -241,12 +240,12 @@ def _native_select_experts(
top_k: int, top_k: int,
use_grouped_topk: bool, use_grouped_topk: bool,
renormalize: bool, renormalize: bool,
topk_group: Optional[int] = None, topk_group: int | None = None,
num_expert_group: Optional[int] = None, num_expert_group: int | None = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Callable | None = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, e_score_correction_bias: torch.Tensor | None = None,
global_num_experts: Optional[torch.Tensor] = None global_num_experts: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
""" """
Select top-k experts based on router logits. Select top-k experts based on router logits.
@@ -285,7 +284,8 @@ def _native_select_experts(
renormalize=renormalize, renormalize=renormalize,
topk_group=topk_group, topk_group=topk_group,
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
)
if custom_routing_function is not None: if custom_routing_function is not None:
topk_weights, topk_ids = custom_routing_function( topk_weights, topk_ids = custom_routing_function(
@@ -293,7 +293,8 @@ def _native_select_experts(
gating_output=router_logits, gating_output=router_logits,
topk=top_k, topk=top_k,
renormalize=renormalize, renormalize=renormalize,
global_num_experts=global_num_experts) global_num_experts=global_num_experts,
)
# Required by npu_moe_init_routing # Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32) topk_ids = topk_ids.to(torch.int32)
return topk_weights, topk_ids return topk_weights, topk_ids
@@ -318,8 +319,7 @@ def zero_experts_compute(
if zero_expert_type == "identity": if zero_expert_type == "identity":
zero_expert_mask = expert_indices < num_experts zero_expert_mask = expert_indices < num_experts
zero_expert_scales = expert_scales.clone() zero_expert_scales = expert_scales.clone()
zero_expert_scales = torch.where(zero_expert_mask, 0.0, zero_expert_scales = torch.where(zero_expert_mask, 0.0, zero_expert_scales)
zero_expert_scales)
hidden_states = hidden_states.unsqueeze(1) hidden_states = hidden_states.unsqueeze(1)
zero_expert_scales = zero_expert_scales.unsqueeze(2) zero_expert_scales = zero_expert_scales.unsqueeze(2)

View File

@@ -14,40 +14,37 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# #
from collections.abc import Callable
from dataclasses import dataclass, field from dataclasses import dataclass, field
from functools import wraps from functools import wraps
from typing import Callable, Optional
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.config import get_current_vllm_config from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group, from vllm.distributed import get_dp_group, get_ep_group, get_tp_group, tensor_model_parallel_all_reduce
tensor_model_parallel_all_reduce)
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.logger import logger from vllm.logger import logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import ( from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map
FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map) from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm.model_executor.layers.fused_moe.shared_fused_moe import \
SharedFusedMoE
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import init_eplb_config from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
from vllm_ascend.flash_common3_context import (get_flash_common3_context, from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
set_flash_common3_context) from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
from vllm_ascend.ops.fused_moe.experts_selector import (select_experts, from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
zero_experts_compute)
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
FusedExpertsResult,
setup_moe_comm_method)
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.utils import (AscendDeviceType, enable_sp, from vllm_ascend.utils import (
get_ascend_device_type, maybe_trans_nz, enable_sp,
npu_stream_switch, shared_expert_dp_enabled, maybe_trans_nz,
shared_experts_calculation_stream, npu_stream_switch,
vllm_version_is) shared_expert_dp_enabled,
shared_experts_calculation_stream,
vllm_version_is,
)
@dataclass @dataclass
class FusedMoEResult: class FusedMoEResult:
@@ -64,46 +61,43 @@ class FusedMoEEvents:
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod): class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
def __init__(self, moe: FusedMoEConfig = None): def __init__(self, moe: FusedMoEConfig = None):
super().__init__(moe=moe) super().__init__(moe=moe)
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
def process_weights_after_loading(self, layer): def process_weights_after_loading(self, layer):
super(UnquantizedFusedMoEMethod, super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
self).process_weights_after_loading(layer)
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose( w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(1, 2).contiguous()
1, 2).contiguous()
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False) layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose( w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(1, 2).contiguous()
1, 2).contiguous()
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False) layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data) layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data) layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
def apply(self, def apply(
layer: torch.nn.Module, self,
x: torch.Tensor, layer: torch.nn.Module,
use_grouped_topk: bool, x: torch.Tensor,
top_k: int, use_grouped_topk: bool,
router_logits: torch.Tensor, top_k: int,
renormalize: bool, router_logits: torch.Tensor,
topk_group: Optional[int] = None, renormalize: bool,
num_expert_group: Optional[int] = None, topk_group: int | None = None,
custom_routing_function: Optional[Callable] = None, num_expert_group: int | None = None,
scoring_func: str = "softmax", custom_routing_function: Callable | None = None,
routed_scaling_factor: float = 1.0, scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None, routed_scaling_factor: float = 1.0,
global_num_experts: int = -1, e_score_correction_bias: torch.Tensor | None = None,
expert_map: Optional[torch.Tensor] = None, global_num_experts: int = -1,
apply_router_weight_on_input: bool = False, expert_map: torch.Tensor | None = None,
enable_force_load_balance: bool = False, apply_router_weight_on_input: bool = False,
log2phy: torch.Tensor = None, enable_force_load_balance: bool = False,
**kwargs) -> torch.Tensor: log2phy: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0) zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None) zero_expert_type = getattr(layer, "zero_expert_type", None)
topk_weights, topk_ids = select_experts( topk_weights, topk_ids = select_experts(
@@ -118,7 +112,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
scoring_func=scoring_func, scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor, routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias, e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts) global_num_experts=global_num_experts,
)
if zero_expert_num > 0 and zero_expert_type is not None: if zero_expert_num > 0 and zero_expert_type is not None:
topk_ids, topk_weights, zero_expert_result = zero_experts_compute( topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
@@ -134,11 +129,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
# to avoid accumulating too much tokens on a single rank. # to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs. # currently it is only activated when doing profile runs.
if enable_force_load_balance: if enable_force_load_balance:
random_matrix = torch.rand(topk_ids.size(0), random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device)
global_num_experts, topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
device=topk_ids.device)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
moe_comm_method = get_forward_context().moe_comm_method moe_comm_method = get_forward_context().moe_comm_method
final_hidden_states = moe_comm_method.fused_experts( final_hidden_states = moe_comm_method.fused_experts(
@@ -151,7 +143,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
dynamic_eplb=self.dynamic_eplb, dynamic_eplb=self.dynamic_eplb,
log2phy=log2phy, log2phy=log2phy,
mc2_mask=kwargs.get("mc2_mask", None)) mc2_mask=kwargs.get("mc2_mask"),
)
if zero_expert_num > 0 and zero_expert_type is not None: if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result final_hidden_states += zero_expert_result
return final_hidden_states return final_hidden_states
@@ -159,7 +152,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
class AscendFusedMoE(FusedMoE): class AscendFusedMoE(FusedMoE):
moe_counter = -1 moe_counter = -1
gate_stream: Optional[torch.npu.Stream] = None gate_stream: torch.npu.Stream | None = None
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@@ -174,11 +167,9 @@ class AscendFusedMoE(FusedMoE):
self.log2phy = None self.log2phy = None
if self.quant_config is None: if self.quant_config is None:
self.quant_method = AscendUnquantizedFusedMoEMethod( self.quant_method = AscendUnquantizedFusedMoEMethod(self.moe_config)
self.moe_config)
else: else:
self.quant_method = self.quant_config.get_quant_method( self.quant_method = self.quant_config.get_quant_method(self, self.layer_name)
self, self.layer_name)
assert self.quant_method is not None assert self.quant_method is not None
@@ -195,28 +186,32 @@ class AscendFusedMoE(FusedMoE):
if self.custom_routing_function is None and self.e_score_correction_bias is not None: if self.custom_routing_function is None and self.e_score_correction_bias is not None:
vllm_config = get_current_vllm_config() vllm_config = get_current_vllm_config()
self.e_score_correction_bias.data = self.e_score_correction_bias.data.to( self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(
dtype=vllm_config.model_config.dtype) dtype=vllm_config.model_config.dtype
)
# init moe # init moe
eplb_config = ascend_config.eplb_config eplb_config = ascend_config.eplb_config
self.global_expert_map, self._expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config( self.global_expert_map, self._expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config(
eplb_config, self.moe_instance_id, self.moe_config) eplb_config, self.moe_instance_id, self.moe_config
)
self.global_num_experts = num_experts + self.global_redundant_expert_num self.global_num_experts = num_experts + self.global_redundant_expert_num
self.dynamic_eplb = eplb_config.dynamic_eplb and (self.log2phy self.dynamic_eplb = eplb_config.dynamic_eplb and (self.log2phy is not None)
is not None) self.local_num_experts = (
self.local_num_experts = (torch.sum( torch.sum(self._expert_map != -1).item() if self._expert_map is not None else self.global_num_experts
self._expert_map != -1).item() if self._expert_map is not None else )
self.global_num_experts)
if self._expert_map is not None: if self._expert_map is not None:
logger.info_once( logger.info_once(
"[EP Rank %s/%s] Expert parallelism is enabled. Local/global" "[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
" number of experts: %s/%s. Experts local to global index map:" " number of experts: %s/%s. Experts local to global index map:"
" %s.", self.ep_rank, self.ep_size, self.local_num_experts, " %s.",
self.ep_rank,
self.ep_size,
self.local_num_experts,
self.global_num_experts, self.global_num_experts,
get_compressed_expert_map(self._expert_map)) get_compressed_expert_map(self._expert_map),
)
if self.dynamic_eplb: if self.dynamic_eplb:
self.moe_load = torch.zeros(self.local_num_experts, self.moe_load = torch.zeros(self.local_num_experts, dtype=torch.int64).npu()
dtype=torch.int64).npu()
self.moe_config.num_experts = self.global_num_experts self.moe_config.num_experts = self.global_num_experts
self.moe_config.num_local_experts = self.local_num_experts self.moe_config.num_local_experts = self.local_num_experts
@@ -225,14 +220,12 @@ class AscendFusedMoE(FusedMoE):
moe_quant_params = { moe_quant_params = {
"num_experts": self.local_num_experts, "num_experts": self.local_num_experts,
"hidden_size": self.hidden_size, "hidden_size": self.hidden_size,
"intermediate_size_per_partition": "intermediate_size_per_partition": self.intermediate_size_per_partition,
self.intermediate_size_per_partition,
"params_dtype": self.params_dtype, "params_dtype": self.params_dtype,
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
} }
# need full intermediate size pre-sharding for WNA16 act order # need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__ if self.quant_method.__class__.__name__ in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod"):
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
moe_quant_params["intermediate_size_full"] = intermediate_size moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
@@ -243,15 +236,14 @@ class AscendFusedMoE(FusedMoE):
def _get_quant_type(self) -> QuantType: def _get_quant_type(self) -> QuantType:
quant_method = self.quant_method quant_method = self.quant_method
if not hasattr(quant_method, if not hasattr(quant_method, "quant_method") or quant_method.quant_method is None:
"quant_method") or quant_method.quant_method is None:
return QuantType.NONE return QuantType.NONE
method = quant_method.quant_method method = quant_method.quant_method
if hasattr(method, "quant_type"): if hasattr(method, "quant_type"):
from vllm_ascend.quantization.methods.base import \ from vllm_ascend.quantization.methods.base import QuantType as SchemeQuantType
QuantType as SchemeQuantType
scheme_quant_type = method.quant_type scheme_quant_type = method.quant_type
if scheme_quant_type == SchemeQuantType.W8A8: if scheme_quant_type == SchemeQuantType.W8A8:
return QuantType.W8A8 return QuantType.W8A8
@@ -270,22 +262,18 @@ class AscendFusedMoE(FusedMoE):
if self.moe_load is not None: if self.moe_load is not None:
self.moe_load.zero_() self.moe_load.zero_()
def maybe_all_reduce_tensor_model_parallel( def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
self, final_hidden_states: torch.Tensor):
"""NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`, """NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`,
and `alltoallcommimpl`, we do not need to all-reduce the final outputs since and `alltoallcommimpl`, we do not need to all-reduce the final outputs since
the outputs are already aggregated across tensor parallel ranks in the the outputs are already aggregated across tensor parallel ranks in the
`finalize` function. In `allgathercommimpl`, we still need to all-reduce the `finalize` function. In `allgathercommimpl`, we still need to all-reduce the
outputs since each rank only has partial outputs. outputs since each rank only has partial outputs.
""" """
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel( return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
final_hidden_states)
def forward_impl( # type: ignore[override] def forward_impl( # type: ignore[override]
self, self, hidden_states: torch.Tensor, router_logits: torch.Tensor, return_with_event: bool = False
hidden_states: torch.Tensor, ) -> torch.Tensor | FusedMoEResult:
router_logits: torch.Tensor,
return_with_event: bool = False) -> torch.Tensor | FusedMoEResult:
assert self.quant_method is not None assert self.quant_method is not None
forward_context = get_forward_context() forward_context = get_forward_context()
@@ -301,15 +289,16 @@ class AscendFusedMoE(FusedMoE):
fc3_context = get_flash_common3_context() fc3_context = get_flash_common3_context()
assert fc3_context is not None assert fc3_context is not None
AscendFusedMoE.gate_stream.wait_stream(torch.npu.current_stream()) AscendFusedMoE.gate_stream.wait_stream(torch.npu.current_stream())
with npu_stream_switch(AscendFusedMoE.gate_stream, with npu_stream_switch(AscendFusedMoE.gate_stream, enabled=self.multistream_overlap_gate):
enabled=self.multistream_overlap_gate):
# share_expert # share_expert
assert fc3_context.shared_experts is not None assert fc3_context.shared_experts is not None
shared_out = fc3_context.shared_experts(hidden_states) shared_out = fc3_context.shared_experts(hidden_states)
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel` # NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
moe_comm_type = forward_context.moe_comm_type moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \ if (
and not shared_expert_dp_enabled(): moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
and not shared_expert_dp_enabled()
):
shared_out = tensor_model_parallel_all_reduce(shared_out) shared_out = tensor_model_parallel_all_reduce(shared_out)
set_flash_common3_context(shared_out=shared_out) set_flash_common3_context(shared_out=shared_out)
@@ -325,24 +314,22 @@ class AscendFusedMoE(FusedMoE):
scoring_func=self.scoring_func, scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor, routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias, e_score_correction_bias=self.e_score_correction_bias,
global_num_experts=self.global_num_experts) global_num_experts=self.global_num_experts,
)
if isinstance(forward_context.moe_comm_method, if isinstance(forward_context.moe_comm_method, AllGatherCommImpl):
AllGatherCommImpl): topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_weights, True, True)
topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_ids, True, True)
topk_weights, True, True)
topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
topk_ids, True, True)
set_flash_common3_context(topk_weights=topk_weights, set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)
topk_ids=topk_ids)
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare( hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states, hidden_states=hidden_states,
router_logits=router_logits, router_logits=router_logits,
replace_allreduce=forward_context.sp_enabled, replace_allreduce=forward_context.sp_enabled,
enable_shared_expert_dp=self.enable_shared_expert_dp, enable_shared_expert_dp=self.enable_shared_expert_dp,
quant_type=self.quant_type) quant_type=self.quant_type,
)
# Make sure the default stream waits for the gate stream to finish. # Make sure the default stream waits for the gate stream to finish.
if self.multistream_overlap_gate: if self.multistream_overlap_gate:
@@ -375,39 +362,45 @@ class AscendFusedMoE(FusedMoE):
enable_force_load_balance=enable_force_load_balance, enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy, log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num, global_redundant_expert_num=self.global_redundant_expert_num,
mc2_mask=mc2_mask) mc2_mask=mc2_mask,
)
if self.dynamic_eplb: if self.dynamic_eplb:
expert_tokens = fused_experts_results.expert_tokens expert_tokens = fused_experts_results.expert_tokens
group_list_type = fused_experts_results.group_list_type group_list_type = fused_experts_results.group_list_type
assert expert_tokens is not None and group_list_type is not None, \ assert expert_tokens is not None and group_list_type is not None, (
"expert_tokens and group_list_type should not be None when dynamic_eplb is enabled." "expert_tokens and group_list_type should not be None when dynamic_eplb is enabled."
local_load = expert_tokens if group_list_type == 1 else \ )
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]]) local_load = (
expert_tokens
if group_list_type == 1
else torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
)
self.moe_load.add_(local_load) self.moe_load.add_(local_load)
routed_out = forward_context.moe_comm_method.finalize( routed_out = forward_context.moe_comm_method.finalize(
hidden_states=fused_experts_results.routed_out, hidden_states=fused_experts_results.routed_out,
reduce_results=self.reduce_results, reduce_results=self.reduce_results,
context_metadata=context_metadata) context_metadata=context_metadata,
)
if return_with_event: if return_with_event:
return FusedMoEResult( return FusedMoEResult(
routed_out=routed_out, routed_out=routed_out,
before_dispatch_evt=fused_experts_results.before_dispatch_evt, before_dispatch_evt=fused_experts_results.before_dispatch_evt,
before_combine_evt=fused_experts_results.before_combine_evt) before_combine_evt=fused_experts_results.before_combine_evt,
)
else: else:
# The vLLM FusedMoE forward_impl does not return events. # The vLLM FusedMoE forward_impl does not return events.
return routed_out return routed_out
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE): class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
def __init__( def __init__(
self, self,
shared_experts: torch.nn.Module, shared_experts: torch.nn.Module,
gate: Optional[torch.nn.Module] = None, gate: torch.nn.Module | None = None,
use_overlapped: bool = True, use_overlapped: bool = True,
routed_input_transform: Optional[torch.nn.Module] = None, routed_input_transform: torch.nn.Module | None = None,
**kwargs, **kwargs,
): ):
AscendFusedMoE.__init__(self, **kwargs) AscendFusedMoE.__init__(self, **kwargs)
@@ -418,16 +411,12 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
self.use_overlapped = use_overlapped self.use_overlapped = use_overlapped
self.shared_expert_stream = None self.shared_expert_stream = None
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
self.multistream_overlap_shared_expert = \ self.multistream_overlap_shared_expert = (
ascend_config.multistream_overlap_shared_expert and \ ascend_config.multistream_overlap_shared_expert and self._shared_experts is not None
self._shared_experts is not None )
self.multistream_overlap_gate = \ self.multistream_overlap_gate = ascend_config.multistream_overlap_gate and self._shared_experts is not None
ascend_config.multistream_overlap_gate and \
self._shared_experts is not None
if enable_sp(): if enable_sp():
logger.info_once( logger.info_once("Sequence parallelism is enabled, shared experts are replicated for best performance.")
"Sequence parallelism is enabled, shared experts are replicated for best performance."
)
self._gate = gate self._gate = gate
@@ -447,20 +436,15 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore
def _shared_experts_part1(self, hidden_states: torch.Tensor): def _shared_experts_part1(self, hidden_states: torch.Tensor):
shared_gate_up, _ = self._shared_experts.gate_up_proj( shared_gate_up, _ = self._shared_experts.gate_up_proj(hidden_states) # type: ignore
hidden_states) # type: ignore
return shared_gate_up return shared_gate_up
def _shared_experts_part2(self, hidden_states: torch.Tensor, def _shared_experts_part2(self, hidden_states: torch.Tensor, shared_gate_up: torch.Tensor):
shared_gate_up: torch.Tensor): shared_act = self._shared_experts.act_fn(shared_gate_up) # type: ignore
shared_act = self._shared_experts.act_fn( shared_out, _ = self._shared_experts.down_proj(shared_act) # type: ignore
shared_gate_up) # type: ignore
shared_out, _ = self._shared_experts.down_proj(
shared_act) # type: ignore
# Qwen3-Next specific gating mechanism # Qwen3-Next specific gating mechanism
if hasattr(self._shared_experts, "expert_gate") and \ if hasattr(self._shared_experts, "expert_gate") and self._shared_experts.expert_gate is not None:
self._shared_experts.expert_gate is not None:
gate_out, _ = self._shared_experts.expert_gate(hidden_states) # type: ignore gate_out, _ = self._shared_experts.expert_gate(hidden_states) # type: ignore
shared_out = F.sigmoid(gate_out) * shared_out shared_out = F.sigmoid(gate_out) * shared_out
return shared_out return shared_out
@@ -468,9 +452,9 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
def _validate_shared_expert_consistency(self): def _validate_shared_expert_consistency(self):
"""Validate that split shared expert computation matches integrated """Validate that split shared expert computation matches integrated
computation.""" computation."""
test_input = torch.rand( test_input = (
10, self.hidden_size, device='npu', dtype=self.moe_config.in_dtype torch.rand(10, self.hidden_size, device="npu", dtype=self.moe_config.in_dtype) * 2 - 1
) * 2 - 1 # Random input for testing, scoped to [-1, 1] ) # Random input for testing, scoped to [-1, 1]
integrated_out = self._shared_experts(test_input) integrated_out = self._shared_experts(test_input)
part1_out = self._shared_experts_part1(test_input) part1_out = self._shared_experts_part1(test_input)
@@ -478,25 +462,19 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
if not torch.allclose(integrated_out, split_out): if not torch.allclose(integrated_out, split_out):
diff = (integrated_out - split_out).abs() diff = (integrated_out - split_out).abs()
logger.error( logger.error("SharedFusedMoE shared experts split computation does not match the integrated computation.")
"SharedFusedMoE shared experts split computation does not "
"match the integrated computation.")
logger.error(f"Max absolute difference: {diff.max().item()}") logger.error(f"Max absolute difference: {diff.max().item()}")
logger.error("Integrated output - sum: %s, norm: %s", logger.error(
integrated_out.sum().item(), "Integrated output - sum: %s, norm: %s", integrated_out.sum().item(), integrated_out.norm().item()
integrated_out.norm().item()) )
logger.error("Split output - sum: %s, norm: %s", logger.error("Split output - sum: %s, norm: %s", split_out.sum().item(), split_out.norm().item())
split_out.sum().item(),
split_out.norm().item())
raise ValueError( raise ValueError(
"SharedFusedMoE shared experts split computation does not " "SharedFusedMoE shared experts split computation does not match the integrated computation."
"match the integrated computation.") )
logger.info_once( logger.info_once("SharedFusedMoE shared experts split computation matches the integrated computation.")
"SharedFusedMoE shared experts split computation matches the "
"integrated computation.")
@property @property
def gate(self) -> Optional[torch.nn.Module]: def gate(self) -> torch.nn.Module | None:
return self._gate if self.use_overlapped else None return self._gate if self.use_overlapped else None
@property @property
@@ -530,8 +508,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
) )
return shared_out, fused_out return shared_out, fused_out
def _forward_shared_experts(self, hidden_states: torch.Tensor, def _forward_shared_experts(self, hidden_states: torch.Tensor, fused_moe_evts: FusedMoEEvents):
fused_moe_evts: FusedMoEEvents):
if self._shared_experts is None: if self._shared_experts is None:
return None return None
@@ -539,11 +516,9 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
if evt is not None: if evt is not None:
torch.npu.current_stream().wait_event(evt) torch.npu.current_stream().wait_event(evt)
with npu_stream_switch(shared_experts_calculation_stream(), with npu_stream_switch(shared_experts_calculation_stream(), enabled=self.multistream_overlap_shared_expert):
enabled=self.multistream_overlap_shared_expert):
# Ensure the shared experts wait for hidden_states to be ready. # Ensure the shared experts wait for hidden_states to be ready.
torch.npu.current_stream().wait_event( torch.npu.current_stream().wait_event(fused_moe_evts.before_routed_experts)
fused_moe_evts.before_routed_experts)
# Execute the gate projection and activation concurrently with the # Execute the gate projection and activation concurrently with the
# dispatch communication. # dispatch communication.
maybe_wait_event(fused_moe_evts.before_dispatch) maybe_wait_event(fused_moe_evts.before_dispatch)
@@ -556,20 +531,22 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
# Make sure the default stream waits for the shared experts stream to # Make sure the default stream waits for the shared experts stream to
# finish. # finish.
if self.multistream_overlap_shared_expert: if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream( torch.npu.current_stream().wait_stream(shared_experts_calculation_stream())
shared_experts_calculation_stream())
# NOTE: This is exactly the opposite of # NOTE: This is exactly the opposite of
# `maybe_all_reduce_tensor_model_parallel` # `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context() forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \ if (
and not shared_expert_dp_enabled(): moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
and not shared_expert_dp_enabled()
):
shared_out = tensor_model_parallel_all_reduce(shared_out) shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out return shared_out
def forward_impl( # type: ignore[override] def forward_impl( # type: ignore[override]
self, hidden_states: torch.Tensor, router_logits: torch.Tensor): self, hidden_states: torch.Tensor, router_logits: torch.Tensor
):
if self.multistream_overlap_gate: if self.multistream_overlap_gate:
set_flash_common3_context(shared_experts=self._shared_experts) set_flash_common3_context(shared_experts=self._shared_experts)
@@ -596,6 +573,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
before_routed_experts=before_routed_experts, before_routed_experts=before_routed_experts,
before_dispatch=fused_moe_results.before_dispatch_evt, before_dispatch=fused_moe_results.before_dispatch_evt,
before_combine=fused_moe_results.before_combine_evt, before_combine=fused_moe_results.before_combine_evt,
)) ),
)
return shared_out, routed_out return shared_out, routed_out

View File

@@ -17,7 +17,6 @@ from __future__ import annotations
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from typing import Dict, Optional
import torch import torch
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
@@ -27,18 +26,24 @@ import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.prepare_finalize import ( from vllm_ascend.ops.fused_moe.prepare_finalize import (
PrepareAndFinalize, PrepareAndFinalizeWithAll2All, PrepareAndFinalize,
PrepareAndFinalizeWithAllGather, PrepareAndFinalizeWithMC2, QuantType) PrepareAndFinalizeWithAll2All,
PrepareAndFinalizeWithAllGather,
PrepareAndFinalizeWithMC2,
QuantType,
)
from vllm_ascend.ops.fused_moe.token_dispatcher import ( from vllm_ascend.ops.fused_moe.token_dispatcher import (
MoETokenDispatcher, TokenDispatcherWithAll2AllV, MoETokenDispatcher,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2) TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather,
TokenDispatcherWithMC2,
)
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {} _MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}
def get_moe_comm_method( def get_moe_comm_method(moe_comm_type: MoECommType | None) -> MoECommMethod | None:
moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]: return _MoECommMethods.get(moe_comm_type)
return _MoECommMethods.get(moe_comm_type, None)
def setup_moe_comm_method(moe_config): def setup_moe_comm_method(moe_config):
@@ -50,6 +55,7 @@ def setup_moe_comm_method(moe_config):
def set_gmmswigluquant_method(): def set_gmmswigluquant_method():
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
ascend_config = get_ascend_config() ascend_config = get_ascend_config()
return ascend_config.ascend_fusion_config.fusion_ops_gmmswigluquant return ascend_config.ascend_fusion_config.fusion_ops_gmmswigluquant
@@ -84,51 +90,46 @@ class MoECommMethod(ABC):
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
quant_type: QuantType = QuantType.NONE, quant_type: QuantType = QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
Optional[torch.Tensor]]:
hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare( hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
hidden_states, router_logits, enable_shared_expert_dp, hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce, quant_type
replace_allreduce, quant_type) )
return hidden_states, router_logits, mc2_mask, context_metadata return hidden_states, router_logits, mc2_mask, context_metadata
def finalize(self, def finalize(
hidden_states: torch.Tensor, self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
reduce_results: bool, ) -> torch.Tensor:
context_metadata: Optional[dict] = None) -> torch.Tensor: hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, context_metadata)
hidden_states = self.prepare_finalize.finalize(hidden_states,
reduce_results,
context_metadata)
return hidden_states return hidden_states
def fused_experts( def fused_experts(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor], w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor], w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False, use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
expert_map: Optional[torch.Tensor] = None, expert_map: torch.Tensor | None = None,
w1_scale: Optional[list[torch.Tensor]] = None, w1_scale: list[torch.Tensor] | None = None,
w2_scale: Optional[list[torch.Tensor]] = None, w2_scale: list[torch.Tensor] | None = None,
w1_scale_bias: torch.Tensor = None, w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None, w1_offset: torch.Tensor | None = None,
w2_offset: Optional[torch.Tensor] = None, w2_offset: torch.Tensor | None = None,
# For load balance # For load balance
log2phy: torch.Tensor = None, log2phy: torch.Tensor = None,
need_trans: bool = False, need_trans: bool = False,
dynamic_eplb: bool = False, dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None, mc2_mask: torch.Tensor = None,
pertoken_scale: Optional[torch.Tensor] = None): pertoken_scale: torch.Tensor | None = None,
):
# Check constraints # Check constraints
assert hidden_states.dtype in [ assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
torch.float32, torch.float16, torch.bfloat16, torch.int8
]
moe_comm_method = get_forward_context().moe_comm_method moe_comm_method = get_forward_context().moe_comm_method
assert moe_comm_method is not None, "Missing communication context" assert moe_comm_method is not None, "Missing communication context"
@@ -143,13 +144,13 @@ class MoECommMethod(ABC):
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
expert_map=expert_map, expert_map=expert_map,
global_redundant_expert_num=self.moe_config. global_redundant_expert_num=self.moe_config.global_redundant_expert_num,
global_redundant_expert_num,
mc2_mask=mc2_mask, mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8, with_quant=use_int8_w8a8 or use_int4_w4a8,
dynamic_eplb=dynamic_eplb, dynamic_eplb=dynamic_eplb,
pertoken_scale=pertoken_scale) pertoken_scale=pertoken_scale,
)
mlp_output = unified_apply_mlp( mlp_output = unified_apply_mlp(
hidden_states=dispatch_results.hidden_states, hidden_states=dispatch_results.hidden_states,
@@ -168,29 +169,29 @@ class MoECommMethod(ABC):
with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16, with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16,
fusion=use_int8_w8a8 and self.use_fusion_ops, fusion=use_int8_w8a8 and self.use_fusion_ops,
need_trans=need_trans, need_trans=need_trans,
dynamic_eplb=dynamic_eplb) dynamic_eplb=dynamic_eplb,
)
before_combine_evt = torch.npu.current_stream().record_event() before_combine_evt = torch.npu.current_stream().record_event()
combine_results = self.token_dispatcher.token_combine( combine_results = self.token_dispatcher.token_combine(
hidden_states=mlp_output, hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata
context_metadata=dispatch_results.context_metadata) )
return FusedExpertsResult( return FusedExpertsResult(
routed_out=combine_results.routed_out, routed_out=combine_results.routed_out,
before_dispatch_evt=before_dispatch_evt, before_dispatch_evt=before_dispatch_evt,
before_combine_evt=before_combine_evt, before_combine_evt=before_combine_evt,
group_list_type=dispatch_results.group_list_type, group_list_type=dispatch_results.group_list_type,
expert_tokens=dispatch_results.group_list) expert_tokens=dispatch_results.group_list,
)
@abstractmethod @abstractmethod
def _get_token_dispatcher(self) -> MoETokenDispatcher: def _get_token_dispatcher(self) -> MoETokenDispatcher:
raise NotImplementedError( raise NotImplementedError("_get_token_dispatcher function not implemented.")
"_get_token_dispatcher function not implemented.")
@abstractmethod @abstractmethod
def _get_prepare_finalize(self) -> PrepareAndFinalize: def _get_prepare_finalize(self) -> PrepareAndFinalize:
raise NotImplementedError( raise NotImplementedError("_get_prepare_finalize function not implemented.")
"_get_prepare_finalize function not implemented.")
class AllGatherCommImpl(MoECommMethod): class AllGatherCommImpl(MoECommMethod):
@@ -216,7 +217,8 @@ class AllGatherCommImpl(MoECommMethod):
return TokenDispatcherWithAllGather( return TokenDispatcherWithAllGather(
top_k=self.moe_config.experts_per_token, top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts, num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts) num_local_experts=self.moe_config.num_local_experts,
)
def _get_prepare_finalize(self): def _get_prepare_finalize(self):
return PrepareAndFinalizeWithAllGather(self.moe_config) return PrepareAndFinalizeWithAllGather(self.moe_config)
@@ -253,7 +255,8 @@ class AlltoAllCommImpl(MoECommMethod):
return TokenDispatcherWithAll2AllV( return TokenDispatcherWithAll2AllV(
top_k=self.moe_config.experts_per_token, top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts, num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts) num_local_experts=self.moe_config.num_local_experts,
)
def _get_prepare_finalize(self): def _get_prepare_finalize(self):
return PrepareAndFinalizeWithAll2All(self.moe_config) return PrepareAndFinalizeWithAll2All(self.moe_config)
@@ -276,36 +279,36 @@ class FusedMC2CommImpl(MoECommMethod):
return PrepareAndFinalizeWithMC2(self.moe_config) return PrepareAndFinalizeWithMC2(self.moe_config)
def fused_experts( def fused_experts(
self, self,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor], w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor], w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
activation: str = "silu", activation: str = "silu",
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False, use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False, use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False, use_int4_w4a16: bool = False,
expert_map: Optional[torch.Tensor] = None, expert_map: torch.Tensor | None = None,
w1_scale: Optional[list[torch.Tensor]] = None, w1_scale: list[torch.Tensor] | None = None,
w2_scale: Optional[list[torch.Tensor]] = None, w2_scale: list[torch.Tensor] | None = None,
w1_scale_bias: torch.Tensor = None, w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None, w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None, w1_offset: torch.Tensor | None = None,
w2_offset: Optional[torch.Tensor] = None, w2_offset: torch.Tensor | None = None,
# For load balance # For load balance
log2phy: torch.Tensor = None, log2phy: torch.Tensor = None,
need_trans: bool = False, need_trans: bool = False,
dynamic_eplb: bool = False, dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None, mc2_mask: torch.Tensor = None,
pertoken_scale: Optional[torch.Tensor] = None): pertoken_scale: torch.Tensor | None = None,
assert not ( ):
w1_scale is None or w2_scale is None assert not (w1_scale is None or w2_scale is None), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \ assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), (
"token_dispatcher must be an instance of TokenDispatcherWithMC2." "token_dispatcher must be an instance of TokenDispatcherWithMC2."
)
# Apply log2phy if needed # Apply log2phy if needed
if log2phy is not None: if log2phy is not None:
@@ -346,10 +349,8 @@ class FusedMC2CommImpl(MoECommMethod):
ep_rank_size=self.token_dispatcher.ep_world_size, ep_rank_size=self.token_dispatcher.ep_world_size,
ep_rank_id=self.token_dispatcher.ep_rank_id, ep_rank_id=self.token_dispatcher.ep_rank_id,
moe_expert_num=self.moe_config.num_experts, moe_expert_num=self.moe_config.num_experts,
global_bs=self.token_dispatcher.global_bs) global_bs=self.token_dispatcher.global_bs,
)
else: else:
raise ValueError( raise ValueError(f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}") return FusedExpertsResult(routed_out=out, group_list_type=group_list_type, expert_tokens=expert_tokens)
return FusedExpertsResult(routed_out=out,
group_list_type=group_list_type,
expert_tokens=expert_tokens)

View File

@@ -14,7 +14,6 @@
# limitations under the License. # limitations under the License.
# This file is a part of the vllm-ascend project. # This file is a part of the vllm-ascend project.
from typing import Optional
import torch import torch
import torch_npu import torch_npu
@@ -23,24 +22,22 @@ from vllm.forward_context import get_forward_context
from vllm.triton_utils import HAS_TRITON from vllm.triton_utils import HAS_TRITON
from vllm_ascend.ascend_forward_context import MoECommType from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor, from vllm_ascend.utils import (
enable_custom_op, get_ascend_device_type, dispose_tensor,
get_weight_prefetch_method) enable_custom_op,
get_weight_prefetch_method,
)
def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
return fusion and dynamic_eplb and enable_custom_op() return fusion and dynamic_eplb and enable_custom_op()
def cumsum_group_list(group_list: torch.Tensor, def cumsum_group_list(
src_list_type: int, group_list: torch.Tensor, src_list_type: int, dst_list_type: int, active_num: int = 0, expert_num: int = 0
dst_list_type: int, ) -> torch.Tensor:
active_num: int = 0,
expert_num: int = 0) -> torch.Tensor:
if src_list_type not in [0, 1, 2]: if src_list_type not in [0, 1, 2]:
raise ValueError( raise ValueError(f"group_list_type should be in [0, 1, 2], but received {src_list_type}")
f"group_list_type should be in [0, 1, 2], but received {src_list_type}"
)
if src_list_type == dst_list_type: if src_list_type == dst_list_type:
return group_list return group_list
@@ -53,10 +50,9 @@ def cumsum_group_list(group_list: torch.Tensor,
if src_list_type == 2 and dst_list_type == 0: if src_list_type == 2 and dst_list_type == 0:
experts = pad(group_list[:, 0], (1, 0)) experts = pad(group_list[:, 0], (1, 0))
tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0)) tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
cumsum_group_list = torch.full(size=(expert_num, ), cumsum_group_list = torch.full(
fill_value=active_num, size=(expert_num,), fill_value=active_num, dtype=group_list.dtype, device=group_list.device
dtype=group_list.dtype, )
device=group_list.device)
for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])): for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
if end > start: if end > start:
@@ -65,30 +61,32 @@ def cumsum_group_list(group_list: torch.Tensor,
return cumsum_group_list return cumsum_group_list
raise NotImplementedError( raise NotImplementedError(
f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. " f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. "
"This feature is under development.") "This feature is under development."
)
def quant_apply_mlp(hidden_states: torch.Tensor, def quant_apply_mlp(
w1: list[torch.Tensor], hidden_states: torch.Tensor,
w1_scale: list[torch.Tensor], w1: list[torch.Tensor],
w2: list[torch.Tensor], w1_scale: list[torch.Tensor],
w2_scale: list[torch.Tensor], w2: list[torch.Tensor],
group_list: torch.Tensor, w2_scale: list[torch.Tensor],
group_list_type: int = 1, group_list: torch.Tensor,
dynamic_scale: torch.Tensor = None, group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None, dynamic_scale: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None, w1_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None, w2_scale_bias: torch.Tensor = None,
w2_offset: Optional[torch.Tensor] = None, w1_offset: torch.Tensor | None = None,
fusion: bool = False, w2_offset: torch.Tensor | None = None,
dynamic_eplb: bool = False) -> torch.Tensor: fusion: bool = False,
dynamic_eplb: bool = False,
) -> torch.Tensor:
if w1_offset is not None: if w1_offset is not None:
unquantized_hidden_states = hidden_states unquantized_hidden_states = hidden_states
quantized_hidden_states = None quantized_hidden_states = None
elif dynamic_scale is None: elif dynamic_scale is None:
unquantized_hidden_states = hidden_states unquantized_hidden_states = hidden_states
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
hidden_states)
# Dispose the original unquantized hidden states # Dispose the original unquantized hidden states
# to save npu memory because they're no longer used. # to save npu memory because they're no longer used.
dispose_tensor(unquantized_hidden_states) dispose_tensor(unquantized_hidden_states)
@@ -103,22 +101,18 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
weight_prefetch_method = get_weight_prefetch_method() weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method: if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess( weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
hidden_states)
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2 is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
if w1_scale_bias is None and w1_offset is None and is_mc2: if w1_scale_bias is None and w1_offset is None and is_mc2:
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
# gmm1: gate_up_proj & act_fn: swiglu # gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = ( hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
torch.ops._C_ascend. x=hidden_states,
grouped_matmul_swiglu_quant_weight_nz_tensor_list( weight=w1,
x=hidden_states, weight_scale=w1_scale,
weight=w1, x_scale=pertoken_scale,
weight_scale=w1_scale, group_list=cumsum_group_list(group_list, group_list_type, 0),
x_scale=pertoken_scale, )
group_list=cumsum_group_list(group_list, group_list_type,
0),
))
elif fusion and not dynamic_eplb: elif fusion and not dynamic_eplb:
# gmm1: gate_up_proj & act_fn: swiglu # gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
@@ -126,7 +120,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
weight=w1[0], weight=w1[0],
group_list=cumsum_group_list(group_list, group_list_type, 0), group_list=cumsum_group_list(group_list, group_list_type, 0),
weight_scale=w1_scale[0], weight_scale=w1_scale[0],
x_scale=pertoken_scale) x_scale=pertoken_scale,
)
if quantized_hidden_states is not None: if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states) dispose_tensor(quantized_hidden_states)
else: else:
@@ -140,7 +135,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list_type=group_list_type, group_list_type=group_list_type,
group_type=0, group_type=0,
group_list=group_list, group_list=group_list,
output_dtype=torch.int32)[0] output_dtype=torch.int32,
)[0]
if quantized_hidden_states is not None: if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states) dispose_tensor(quantized_hidden_states)
# act_fn: swiglu # act_fn: swiglu
@@ -165,7 +161,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list_type=group_list_type, group_list_type=group_list_type,
group_type=0, group_type=0,
group_list=group_list, group_list=group_list,
output_dtype=w2_scale[0].dtype)[0] output_dtype=w2_scale[0].dtype,
)[0]
elif w1_offset is not None: elif w1_offset is not None:
# gmm1: gate_up_proj # gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul( hidden_states = torch_npu.npu_grouped_matmul(
@@ -177,7 +174,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list_type=group_list_type, group_list_type=group_list_type,
group_type=0, group_type=0,
group_list=group_list, group_list=group_list,
output_dtype=_output_dtype)[0] output_dtype=_output_dtype,
)[0]
dispose_tensor(unquantized_hidden_states) dispose_tensor(unquantized_hidden_states)
# act_fn: swiglu # act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states) hidden_states = torch_npu.npu_swiglu(hidden_states)
@@ -191,13 +189,12 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list_type=group_list_type, group_list_type=group_list_type,
group_type=0, group_type=0,
group_list=group_list, group_list=group_list,
output_dtype=_output_dtype)[0] output_dtype=_output_dtype,
)[0]
else: else:
if w1_scale_bias is not None: if w1_scale_bias is not None:
if group_list_type == 0: if group_list_type == 0:
group_list = torch.cat( group_list = torch.cat([group_list[:1], torch.diff(group_list, dim=0)])
[group_list[:1],
torch.diff(group_list, dim=0)])
group_list_type = 1 group_list_type = 1
bias1 = [w1_scale_bias] if not fusion else w1_scale_bias bias1 = [w1_scale_bias] if not fusion else w1_scale_bias
bias2 = [w2_scale_bias] bias2 = [w2_scale_bias]
@@ -206,17 +203,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb): if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
# gmm1: gate_up_proj & act_fn: swiglu # gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = ( hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
torch.ops._C_ascend. x=hidden_states,
grouped_matmul_swiglu_quant_weight_nz_tensor_list( weight=w1,
x=hidden_states, weight_scale=w1_scale,
weight=w1, x_scale=pertoken_scale,
weight_scale=w1_scale, group_list=cumsum_group_list(group_list, group_list_type, 0),
x_scale=pertoken_scale, bias=bias1,
group_list=cumsum_group_list(group_list, group_list_type, )
0),
bias=bias1,
))
elif fusion and not dynamic_eplb: elif fusion and not dynamic_eplb:
# gmm1: gate_up_proj & act_fn: swiglu # gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant( hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
@@ -225,7 +219,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
bias=bias1, bias=bias1,
group_list=cumsum_group_list(group_list, group_list_type, 0), group_list=cumsum_group_list(group_list, group_list_type, 0),
weight_scale=w1_scale[0], weight_scale=w1_scale[0],
x_scale=pertoken_scale) x_scale=pertoken_scale,
)
if quantized_hidden_states is not None: if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states) dispose_tensor(quantized_hidden_states)
else: else:
@@ -241,21 +236,20 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list_type=group_list_type, group_list_type=group_list_type,
group_type=0, group_type=0,
group_list=group_list, group_list=group_list,
output_dtype=_output_dtype)[0] output_dtype=_output_dtype,
)[0]
if quantized_hidden_states is not None: if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states) dispose_tensor(quantized_hidden_states)
# act_fn: swiglu # act_fn: swiglu
if HAS_TRITON: if HAS_TRITON:
from vllm_ascend.ops.triton.activation.swiglu_quant import \ from vllm_ascend.ops.triton.activation.swiglu_quant import swiglu_quant
swiglu_quant
hidden_states, swiglu_out_scale = swiglu_quant( hidden_states, swiglu_out_scale = swiglu_quant(
hidden_states, hidden_states, group_list=group_list, group_list_type=group_list_type
group_list=group_list, )
group_list_type=group_list_type)
else: else:
hidden_states = torch_npu.npu_swiglu(hidden_states) hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant( hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
hidden_states)
# gmm2: down_proj # gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul( hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states], x=[hidden_states],
@@ -267,18 +261,20 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list_type=group_list_type, group_list_type=group_list_type,
group_type=0, group_type=0,
group_list=group_list, group_list=group_list,
output_dtype=_output_dtype)[0] output_dtype=_output_dtype,
)[0]
return hidden_states return hidden_states
def unquant_apply_mlp(hidden_states: torch.Tensor, def unquant_apply_mlp(
w1: torch.Tensor, hidden_states: torch.Tensor,
w2: torch.Tensor, w1: torch.Tensor,
group_list: torch.Tensor, w2: torch.Tensor,
group_list_type: int = 1, group_list: torch.Tensor,
topk_scales: Optional[torch.Tensor] = None, group_list_type: int = 1,
need_trans: bool = True) -> torch.Tensor: topk_scales: torch.Tensor | None = None,
need_trans: bool = True,
) -> torch.Tensor:
if need_trans: if need_trans:
w1 = w1.transpose(1, 2) w1 = w1.transpose(1, 2)
w2 = w2.transpose(1, 2) w2 = w2.transpose(1, 2)
@@ -307,44 +303,50 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,
return hidden_states return hidden_states
def unified_apply_mlp(hidden_states: torch.Tensor, def unified_apply_mlp(
w1: torch.Tensor | list[torch.Tensor], hidden_states: torch.Tensor,
w2: torch.Tensor | list[torch.Tensor], w1: torch.Tensor | list[torch.Tensor],
group_list: torch.Tensor, w2: torch.Tensor | list[torch.Tensor],
w1_scale: Optional[list[torch.Tensor]] = None, group_list: torch.Tensor,
w2_scale: Optional[list[torch.Tensor]] = None, w1_scale: list[torch.Tensor] | None = None,
dynamic_scale: torch.Tensor = None, w2_scale: list[torch.Tensor] | None = None,
group_list_type: int = 1, dynamic_scale: torch.Tensor = None,
w1_scale_bias: torch.Tensor = None, group_list_type: int = 1,
w2_scale_bias: torch.Tensor = None, w1_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None, w2_scale_bias: torch.Tensor = None,
w2_offset: Optional[torch.Tensor] = None, w1_offset: torch.Tensor | None = None,
topk_scales: Optional[torch.Tensor] = None, w2_offset: torch.Tensor | None = None,
with_quant: bool = False, topk_scales: torch.Tensor | None = None,
fusion: bool = False, with_quant: bool = False,
need_trans: bool = True, fusion: bool = False,
dynamic_eplb: bool = False) -> torch.Tensor: need_trans: bool = True,
dynamic_eplb: bool = False,
) -> torch.Tensor:
if with_quant: if with_quant:
assert w1_scale is not None and w2_scale is not None assert w1_scale is not None and w2_scale is not None
return quant_apply_mlp(hidden_states=hidden_states, return quant_apply_mlp(
w1=w1, hidden_states=hidden_states,
w1_scale=w1_scale, w1=w1,
w2=w2, w1_scale=w1_scale,
w2_scale=w2_scale, w2=w2,
group_list=group_list, w2_scale=w2_scale,
dynamic_scale=dynamic_scale, group_list=group_list,
group_list_type=group_list_type, dynamic_scale=dynamic_scale,
w1_scale_bias=w1_scale_bias, group_list_type=group_list_type,
w2_scale_bias=w2_scale_bias, w1_scale_bias=w1_scale_bias,
w1_offset=w1_offset, w2_scale_bias=w2_scale_bias,
w2_offset=w2_offset, w1_offset=w1_offset,
fusion=fusion, w2_offset=w2_offset,
dynamic_eplb=dynamic_eplb) fusion=fusion,
dynamic_eplb=dynamic_eplb,
)
else: else:
return unquant_apply_mlp(hidden_states=hidden_states, return unquant_apply_mlp(
w1=w1, hidden_states=hidden_states,
w2=w2, w1=w1,
group_list=group_list, w2=w2,
group_list_type=group_list_type, group_list=group_list,
topk_scales=topk_scales, group_list_type=group_list_type,
need_trans=need_trans) topk_scales=topk_scales,
need_trans=need_trans,
)

View File

@@ -16,22 +16,23 @@
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from typing import Optional
import torch import torch
import torch.distributed as dist import torch.distributed as dist
import torch.nn as nn import torch.nn as nn
import torch_npu import torch_npu
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_dp_group, get_pcp_group, get_tensor_model_parallel_rank, get_dp_group,
get_tensor_model_parallel_world_size) get_pcp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import get_forward_context from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_config import get_ascend_config from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
from vllm_ascend.utils import (enable_sp, npu_stream_switch, from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
prefill_context_parallel_enable)
class QuantType(Enum): class QuantType(Enum):
@@ -51,7 +52,8 @@ class PrepareAndFinalize(ABC):
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info, moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
sizes, ranks, and communication settings. sizes, ranks, and communication settings.
""" """
quant_stream: Optional[torch.npu.Stream] = None
quant_stream: torch.npu.Stream | None = None
def __init__(self, moe_config: FusedMoEConfig): def __init__(self, moe_config: FusedMoEConfig):
self.moe_config = moe_config self.moe_config = moe_config
@@ -67,9 +69,8 @@ class PrepareAndFinalize(ABC):
router_logits: torch.Tensor, router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
quant_type: QuantType = QuantType.NONE quant_type: QuantType = QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
Optional[torch.Tensor]]:
""" """
Prepare tensors before MoE computation. May involve: Prepare tensors before MoE computation. May involve:
- Padding to align communication boundaries - Padding to align communication boundaries
@@ -92,10 +93,9 @@ class PrepareAndFinalize(ABC):
""" """
raise NotImplementedError("Prepare not implemented.") raise NotImplementedError("Prepare not implemented.")
def finalize(self, def finalize(
hidden_states: torch.Tensor, self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
reduce_results: bool, ) -> torch.Tensor:
context_metadata: Optional[dict] = None) -> torch.Tensor:
""" """
Finalize MoE output. May involve: Finalize MoE output. May involve:
- Gathering sliced tensors across TP ranks - Gathering sliced tensors across TP ranks
@@ -135,9 +135,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
router_logits: torch.Tensor, router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
quant_type=QuantType.NONE quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
Optional[torch.Tensor]]:
""" """
Preparation steps: Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size. 1. Pad hidden_states and router_logits to next multiple of TP size.
@@ -158,33 +157,24 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic) pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
if pad_size > 0: if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states, hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
(0, 0, 0, pad_size)) router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
padded_hidden_states_shape = hidden_states.shape padded_hidden_states_shape = hidden_states.shape
if self.tp_size > 1: if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states, split_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0)
self.tp_size, split_router_logits = torch.tensor_split(router_logits, self.tp_size, dim=0)
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
hidden_states = split_hidden_states[self.tp_rank] hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank] router_logits = split_router_logits[self.tp_rank]
context_metadata = { context_metadata = {"padded_hidden_states_shape": padded_hidden_states_shape}
"padded_hidden_states_shape": padded_hidden_states_shape
}
return hidden_states, router_logits, None, context_metadata return hidden_states, router_logits, None, context_metadata
def finalize(self, def finalize(
hidden_states: torch.Tensor, self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
reduce_results: bool, ) -> torch.Tensor:
context_metadata: Optional[dict] = None) -> torch.Tensor:
""" """
Finalization steps: Finalization steps:
1. If TP > 1, all-gather slices to reconstruct full tensor. 1. If TP > 1, all-gather slices to reconstruct full tensor.
@@ -201,20 +191,16 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
# may share memory with original hidden_states. Since shared # may share memory with original hidden_states. Since shared
# experts may use the original tensor, reusing it would cause # experts may use the original tensor, reusing it would cause
# in-place modification during all_gather, corrupting the data. # in-place modification during all_gather, corrupting the data.
padded_hidden_states_shape = context_metadata[ padded_hidden_states_shape = context_metadata["padded_hidden_states_shape"]
"padded_hidden_states_shape"]
gathered_hidden_states = torch.empty( gathered_hidden_states = torch.empty(
padded_hidden_states_shape, padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype
device=hidden_states.device, )
dtype=hidden_states.dtype) split_hidden_states = torch.tensor_split(gathered_hidden_states, self.tp_size, dim=0)
split_hidden_states = torch.tensor_split( dist.all_gather(list(split_hidden_states), hidden_states, self.moe_config.tp_group.device_group)
gathered_hidden_states, self.tp_size, dim=0)
dist.all_gather(list(split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
hidden_states = gathered_hidden_states hidden_states = gathered_hidden_states
if self.num_tokens < hidden_states.shape[0]: if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens] hidden_states = hidden_states[: self.num_tokens]
return hidden_states return hidden_states
@@ -246,9 +232,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
router_logits: torch.Tensor, router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
quant_type=QuantType.NONE quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
Optional[torch.Tensor]]:
""" """
Preparation steps: Preparation steps:
1. Fetch `mc2_mask` and target padding length from forward context. 1. Fetch `mc2_mask` and target padding length from forward context.
@@ -278,20 +263,14 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
# Pad if necessary (unless shared expert DP is enabled) # Pad if necessary (unless shared expert DP is enabled)
if pad_size > 0 and not self.enable_shared_expert_dp: if pad_size > 0 and not self.enable_shared_expert_dp:
hidden_states = nn.functional.pad(hidden_states, hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
(0, 0, 0, pad_size)) router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
padded_hidden_states_shape = hidden_states.shape padded_hidden_states_shape = hidden_states.shape
# Slice across TP ranks # Slice across TP ranks
if self.tp_size > 1 and not self.enable_shared_expert_dp: if self.tp_size > 1 and not self.enable_shared_expert_dp:
split_hidden_states = torch.tensor_split(hidden_states, split_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0)
self.tp_size, split_router_logits = torch.tensor_split(router_logits, self.tp_size, dim=0)
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
hidden_states = split_hidden_states[self.tp_rank] hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank] router_logits = split_router_logits[self.tp_rank]
@@ -330,9 +309,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
router_logits: torch.Tensor, router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
quant_type=QuantType.NONE quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
Optional[torch.Tensor]]:
""" """
Preparation steps: Preparation steps:
AllGather hidden_states and router_logits to form global tensors. AllGather hidden_states and router_logits to form global tensors.
@@ -341,46 +319,31 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
Tuple of (global_hidden_states, global_router_logits, None) Tuple of (global_hidden_states, global_router_logits, None)
""" """
if enable_sp(): if enable_sp():
return self._prepare_with_ep_group(hidden_states, router_logits, return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
quant_type)
return self._prepare_with_dp_group(hidden_states, router_logits, return self._prepare_with_dp_group(hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce)
enable_shared_expert_dp,
replace_allreduce)
def _prepare_with_ep_group( def _prepare_with_ep_group(
self, self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE
hidden_states: torch.Tensor, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
router_logits: torch.Tensor,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
pertoken_scale = None pertoken_scale = None
if quant_type == QuantType.W8A8: if quant_type == QuantType.W8A8:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant( hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
hidden_states)
if self.multistream_overlap_gate: if self.multistream_overlap_gate:
assert PrepareAndFinalize.quant_stream is not None assert PrepareAndFinalize.quant_stream is not None
PrepareAndFinalize.quant_stream.wait_stream( PrepareAndFinalize.quant_stream.wait_stream(torch.npu.current_stream())
torch.npu.current_stream()) with npu_stream_switch(PrepareAndFinalize.quant_stream, enabled=self.multistream_overlap_gate):
with npu_stream_switch(PrepareAndFinalize.quant_stream, hidden_states = fc3_all_gather_and_maybe_unpad_impl(hidden_states)
enabled=self.multistream_overlap_gate):
hidden_states = fc3_all_gather_and_maybe_unpad_impl(
hidden_states)
else: else:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states, True, True)
hidden_states, True, True) router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(router_logits, True, True)
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
router_logits, True, True)
if pertoken_scale is not None: if pertoken_scale is not None:
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad( pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(pertoken_scale, True, True)
pertoken_scale, True, True)
if self.multistream_overlap_gate: if self.multistream_overlap_gate:
torch.npu.current_stream().wait_stream( torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream)
PrepareAndFinalize.quant_stream)
if pertoken_scale is not None: if pertoken_scale is not None:
return (hidden_states, pertoken_scale), router_logits, None, None return (hidden_states, pertoken_scale), router_logits, None, None
@@ -393,9 +356,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
router_logits: torch.Tensor, router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False, enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False, replace_allreduce: bool = False,
quant_type=QuantType.NONE quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
Optional[torch.Tensor]]:
""" """
Preparation steps: Preparation steps:
1. Fetch max token count across DP group from forward context. 1. Fetch max token count across DP group from forward context.
@@ -413,16 +375,12 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
self.num_tokens = hidden_states.shape[0] self.num_tokens = hidden_states.shape[0]
pad_size = max_tokens_across_dp - self.num_tokens pad_size = max_tokens_across_dp - self.num_tokens
if pad_size > 0: if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states, hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
(0, 0, 0, pad_size)) router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
# All-gather across DP group # All-gather across DP group
hidden_states = self.moe_config.dp_group.all_gather( hidden_states = self.moe_config.dp_group.all_gather(hidden_states, 0)
hidden_states, 0) router_logits = self.moe_config.dp_group.all_gather(router_logits, 0)
router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0)
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1: if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().all_gather( hidden_states = get_pcp_group().all_gather(
@@ -436,10 +394,9 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
return hidden_states, router_logits, None, None return hidden_states, router_logits, None, None
def finalize(self, def finalize(
hidden_states: torch.Tensor, self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
reduce_results: bool, ) -> torch.Tensor:
context_metadata: Optional[dict] = None) -> torch.Tensor:
""" """
Finalization steps: Finalization steps:
Reduce Scatter hidden states. Reduce Scatter hidden states.
@@ -452,8 +409,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
return self._finalize_with_dp_group(hidden_states, reduce_results) return self._finalize_with_dp_group(hidden_states, reduce_results)
def _finalize_with_ep_group(self, def _finalize_with_ep_group(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states: torch.Tensor) -> torch.Tensor:
""" """
Argument `reduce_results` is not needed in this func. Given sequence parallelism is enabled: Argument `reduce_results` is not needed in this func. Given sequence parallelism is enabled:
1. Reduce_results is False usually happens when models have shared experts and need to 1. Reduce_results is False usually happens when models have shared experts and need to
@@ -463,13 +419,11 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
2 Reduce_results is True usually happens when model has no shared experts. We still do reduce scatter 2 Reduce_results is True usually happens when model has no shared experts. We still do reduce scatter
here, then skip allreudce in FusedMoe. here, then skip allreudce in FusedMoe.
""" """
hidden_states = torch.ops.vllm.maybe_pad_and_reduce( hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states, True)
hidden_states, True)
return hidden_states return hidden_states
def _finalize_with_dp_group(self, hidden_states: torch.Tensor, def _finalize_with_dp_group(self, hidden_states: torch.Tensor, reduce_results: bool) -> torch.Tensor:
reduce_results: bool) -> torch.Tensor:
""" """
Finalization steps: Finalization steps:
1. If DP > 1 and not shared expert, reduce-scatter output across DP group. 1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
@@ -481,9 +435,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
""" """
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp: if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
hidden_states = get_dp_group().reduce_scatter(hidden_states, 0) hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
hidden_states = hidden_states[:self.num_tokens] hidden_states = hidden_states[: self.num_tokens]
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1: if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().reduce_scatter(hidden_states, hidden_states = get_pcp_group().reduce_scatter(hidden_states, dim=0)
dim=0)
return hidden_states return hidden_states

View File

@@ -22,7 +22,6 @@
# limitations under the License. # limitations under the License.
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Optional
import torch import torch
import torch_npu import torch_npu
@@ -30,10 +29,8 @@ from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import get_ep_group from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.distributed.parallel_state import get_mc2_group from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.comm_utils import ( from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
async_all_to_all, gather_from_sequence_parallel_region) from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
is_hierarchical_communication_enabled)
@dataclass @dataclass
@@ -52,7 +49,6 @@ class TokenCombineResult:
class MoETokenDispatcher(ABC): class MoETokenDispatcher(ABC):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
""" """
Initialize the MoE Token Dispatcher. Initialize the MoE Token Dispatcher.
@@ -79,26 +75,24 @@ class MoETokenDispatcher(ABC):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
topk_weights: torch.Tensor, topk_weights: torch.Tensor,
topk_ids: torch.Tensor, topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None, expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0, global_redundant_expert_num: int = 0,
mc2_mask: Optional[torch.Tensor] = None, mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False, apply_router_weight_on_input: bool = False,
with_quant: bool = False, with_quant: bool = False,
dynamic_eplb: bool = False, dynamic_eplb: bool = False,
pertoken_scale: Optional[torch.Tensor] = None, pertoken_scale: torch.Tensor | None = None,
) -> TokenDispatchResult: ) -> TokenDispatchResult:
raise NotImplementedError("Dispatch function not implemented.") raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod @abstractmethod
def token_combine(self, def token_combine(
hidden_states: torch.Tensor, self, hidden_states: torch.Tensor, context_metadata: dict, bias: torch.Tensor | None = None
context_metadata: dict, ) -> TokenCombineResult:
bias: torch.Tensor | None = None) -> TokenCombineResult:
raise NotImplementedError("Combine function not implemented.") raise NotImplementedError("Combine function not implemented.")
class TokenDispatcherWithMC2(MoETokenDispatcher): class TokenDispatcherWithMC2(MoETokenDispatcher):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
device_group = get_mc2_group().device_group device_group = get_mc2_group().device_group
@@ -108,10 +102,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
self.ep_rank_id = get_mc2_group().rank_in_group self.ep_rank_id = get_mc2_group().rank_in_group
self.ep_world_size = get_mc2_group().world_size self.ep_world_size = get_mc2_group().world_size
self.enable_dispatch_v2 = hasattr(torch_npu, self.enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
"npu_moe_distribute_dispatch_v2") self.need_extra_args = get_ascend_device_type() == AscendDeviceType.A3
self.need_extra_args = (
get_ascend_device_type() == AscendDeviceType.A3)
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and # NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly # HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
@@ -126,10 +118,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
compilation_config = vllm_config.compilation_config compilation_config = vllm_config.compilation_config
speculative_config = vllm_config.speculative_config speculative_config = vllm_config.speculative_config
tp_size = vllm_config.parallel_config.tensor_parallel_size tp_size = vllm_config.parallel_config.tensor_parallel_size
uniform_decode_query_len = 1 if not speculative_config else \ uniform_decode_query_len = 1 if not speculative_config else 1 + speculative_config.num_speculative_tokens
1 + speculative_config.num_speculative_tokens decode_max_num_seqs = getattr(scheduler_config, "decode_max_num_seqs", 0)
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0)
max_num_reqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs) max_num_reqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
if compilation_config.cudagraph_capture_sizes: if compilation_config.cudagraph_capture_sizes:
max_num_tokens = compilation_config.max_cudagraph_capture_size max_num_tokens = compilation_config.max_cudagraph_capture_size
@@ -167,44 +157,56 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"ep_rank_id": self.ep_rank_id, "ep_rank_id": self.ep_rank_id,
} }
if self.need_extra_args: if self.need_extra_args:
stage1_kwargs.update({ stage1_kwargs.update(
"group_tp": self.moe_all_to_all_group_name, {
"tp_world_size": 1, "group_tp": self.moe_all_to_all_group_name,
"tp_rank_id": 0, "tp_world_size": 1,
}) "tp_rank_id": 0,
}
)
if self.need_expert_scale: if self.need_expert_scale:
stage1_kwargs.update({ stage1_kwargs.update(
"expert_scales": {
topk_weights.to(torch.float32), "expert_scales": topk_weights.to(torch.float32),
}) }
)
kwargs_mc2.update(stage1_kwargs) kwargs_mc2.update(stage1_kwargs)
return kwargs_mc2 return kwargs_mc2
def token_dispatch(self, def token_dispatch(
hidden_states: torch.Tensor, self,
topk_weights: torch.Tensor, hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_weights: torch.Tensor,
expert_map: Optional[torch.Tensor] = None, topk_ids: torch.Tensor,
global_redundant_expert_num: int = 0, expert_map: torch.Tensor | None = None,
mc2_mask: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0,
apply_router_weight_on_input: bool = False, mc2_mask: torch.Tensor | None = None,
with_quant: bool = False, apply_router_weight_on_input: bool = False,
dynamic_eplb: bool = False, with_quant: bool = False,
pertoken_scale: Optional[torch.Tensor] = None): dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
):
self.with_quant = with_quant self.with_quant = with_quant
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights, kwargs_mc2 = self.get_dispatch_mc2_kwargs(
topk_ids, expert_map, hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num
mc2_mask, )
global_redundant_expert_num) output = (
output = torch_npu.npu_moe_distribute_dispatch_v2( torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2)
**kwargs_mc2 if self.enable_dispatch_v2
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch( else torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
**kwargs_mc2) )
# comm_stream.wait_stream(torch.npu.current_stream()) # comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \ (
ep_recv_counts, tp_recv_counts, expand_scales = output[0:7] expand_x,
dynamic_scale,
assist_info_for_combine,
expert_token_nums,
ep_recv_counts,
tp_recv_counts,
expand_scales,
) = output[0:7]
context_metadata = { context_metadata = {
"topk_ids": topk_ids, "topk_ids": topk_ids,
@@ -213,18 +215,19 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"ep_recv_counts": ep_recv_counts, "ep_recv_counts": ep_recv_counts,
"tp_recv_counts": tp_recv_counts, "tp_recv_counts": tp_recv_counts,
"assist_info_for_combine": assist_info_for_combine, "assist_info_for_combine": assist_info_for_combine,
"expand_scales": expand_scales "expand_scales": expand_scales,
} }
group_list_type = 0 group_list_type = 0
return TokenDispatchResult(hidden_states=expand_x, return TokenDispatchResult(
dynamic_scale=dynamic_scale, hidden_states=expand_x,
group_list=expert_token_nums, dynamic_scale=dynamic_scale,
group_list_type=group_list_type, group_list=expert_token_nums,
context_metadata=context_metadata) group_list_type=group_list_type,
context_metadata=context_metadata,
)
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, context_metadata: dict):
context_metadata: dict):
expert_map = context_metadata["expert_map"] expert_map = context_metadata["expert_map"]
topk_ids = context_metadata["topk_ids"] topk_ids = context_metadata["topk_ids"]
topk_weights = context_metadata["topk_weights"] topk_weights = context_metadata["topk_weights"]
@@ -246,9 +249,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
} }
if self.with_quant: if self.with_quant:
tp_recv_counts = torch.empty(1, tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device)
dtype=torch.int32,
device=hidden_states.device)
stage3_kwargs = { stage3_kwargs = {
"ep_send_counts": ep_recv_counts, "ep_send_counts": ep_recv_counts,
@@ -264,12 +265,14 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
stage3_kwargs["expand_idx"] = assist_info_for_combine stage3_kwargs["expand_idx"] = assist_info_for_combine
if self.need_extra_args: if self.need_extra_args:
stage3_kwargs.update({ stage3_kwargs.update(
"tp_send_counts": tp_recv_counts, {
"group_tp": self.moe_all_to_all_group_name, "tp_send_counts": tp_recv_counts,
"tp_world_size": 1, "group_tp": self.moe_all_to_all_group_name,
"tp_rank_id": 0, "tp_world_size": 1,
}) "tp_rank_id": 0,
}
)
kwargs_mc2.update(stage3_kwargs) kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2 return kwargs_mc2
@@ -277,57 +280,58 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
def token_combine(self, hidden_states, context_metadata, bias=None): def token_combine(self, hidden_states, context_metadata, bias=None):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, context_metadata)
context_metadata) combined_output = (
combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \ torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2)
if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2) if self.enable_dispatch_v2
else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
)
return TokenCombineResult(routed_out=combined_output, ) return TokenCombineResult(
routed_out=combined_output,
)
class TokenDispatcherWithAllGather(MoETokenDispatcher): class TokenDispatcherWithAllGather(MoETokenDispatcher):
def __init__(self, **kwargs): def __init__(self, **kwargs):
super().__init__(**kwargs) super().__init__(**kwargs)
self.apply_router_weight_on_input = False self.apply_router_weight_on_input = False
self.max_num_tokens = kwargs.get("max_num_tokens") self.max_num_tokens = kwargs.get("max_num_tokens")
num_experts_local = kwargs.get("num_local_experts", 0) num_experts_local = kwargs.get("num_local_experts", 0)
self.num_experts_local = num_experts_local.item() if torch.is_tensor( self.num_experts_local = (
num_experts_local) else int(num_experts_local) num_experts_local.item() if torch.is_tensor(num_experts_local) else int(num_experts_local)
)
self.original_shape = None self.original_shape = None
self.with_quant = False self.with_quant = False
def token_dispatch(self, def token_dispatch(
hidden_states: torch.Tensor, self,
topk_weights: torch.Tensor, hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_weights: torch.Tensor,
expert_map: Optional[torch.Tensor] = None, topk_ids: torch.Tensor,
global_redundant_expert_num: int = 0, expert_map: torch.Tensor | None = None,
mc2_mask: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0,
apply_router_weight_on_input: bool = False, mc2_mask: torch.Tensor | None = None,
with_quant: bool = False, apply_router_weight_on_input: bool = False,
dynamic_eplb: bool = False, with_quant: bool = False,
pertoken_scale: Optional[torch.Tensor] = None): dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
):
self.with_quant = with_quant self.with_quant = with_quant
self.original_shape = hidden_states.shape self.original_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel() num_tokens = hidden_states.shape[:-1].numel()
self.apply_router_weight_on_input = apply_router_weight_on_input self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input: if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2 assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
), "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape _, topk = topk_weights.shape
assert ( assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
topk == 1 hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * \
topk_weights.to(hidden_states.dtype)
if expert_map is not None: if expert_map is not None:
global_num_experts = len(expert_map) + global_redundant_expert_num global_num_experts = len(expert_map) + global_redundant_expert_num
mask = (expert_map[topk_ids] != -1) mask = expert_map[topk_ids] != -1
topk_weights = topk_weights * mask topk_weights = topk_weights * mask
first_expert_idx = get_ep_group( first_expert_idx = get_ep_group().rank_in_group * self.num_experts_local
).rank_in_group * self.num_experts_local
last_expert_idx = first_expert_idx + self.num_experts_local last_expert_idx = first_expert_idx + self.num_experts_local
else: else:
first_expert_idx = 0 first_expert_idx = 0
@@ -344,15 +348,12 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
expert_tokens_num_type=1, expert_tokens_num_type=1,
expert_tokens_num_flag=True, expert_tokens_num_flag=True,
active_expert_range=[first_expert_idx, last_expert_idx], active_expert_range=[first_expert_idx, last_expert_idx],
quant_mode=1 quant_mode=1 if self.with_quant and pertoken_scale is None else -1,
if self.with_quant and pertoken_scale is None else -1, )
)) )
expert_tokens = expert_tokens.to(torch.int64) expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode group_list_type = 1 # `count` mode
context_metadata = { context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}
"topk_weights": topk_weights,
"expanded_row_idx": expanded_row_idx
}
return TokenDispatchResult( return TokenDispatchResult(
hidden_states=sorted_hidden_states, hidden_states=sorted_hidden_states,
@@ -367,7 +368,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
final_hidden_states = torch_npu.npu_moe_token_unpermute( final_hidden_states = torch_npu.npu_moe_token_unpermute(
permuted_tokens=hidden_states, permuted_tokens=hidden_states,
sorted_indices=torch.abs(context_metadata["expanded_row_idx"]), sorted_indices=torch.abs(context_metadata["expanded_row_idx"]),
probs=context_metadata["topk_weights"]) probs=context_metadata["topk_weights"],
)
if len(self.original_shape) == 3: if len(self.original_shape) == 3:
final_hidden_states = final_hidden_states.view(self.original_shape) final_hidden_states = final_hidden_states.view(self.original_shape)
@@ -398,35 +400,33 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
device=torch.npu.current_device(), device=torch.npu.current_device(),
) )
local_expert_indices_offset = (self.ep_rank * self.num_local_experts) local_expert_indices_offset = self.ep_rank * self.num_local_experts
self.local_expert_indices = [ self.local_expert_indices = [local_expert_indices_offset + i for i in range(self.num_local_experts)]
local_expert_indices_offset + i assert len(self.local_expert_indices) == self.num_local_experts, "Invalid local expert indices"
for i in range(self.num_local_experts)
]
assert (len(self.local_expert_indices) == self.num_local_experts
), "Invalid local expert indices"
for i in range(len(self.local_expert_indices) - 1): for i in range(len(self.local_expert_indices) - 1):
assert (self.local_expert_indices[i] == assert self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1, (
self.local_expert_indices[i + 1] - "local_expert_indices must be continuous"
1), "local_expert_indices must be continuous" )
# TODO: Try local_rank = ep_group.rank_in_group # TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=self.ep_group) local_rank = torch.distributed.get_rank(group=self.ep_group)
backend = self.ep_group._get_backend(torch.device("npu")) backend = self.ep_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank) self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
def token_dispatch(self, def token_dispatch(
hidden_states: torch.Tensor, self,
topk_weights: torch.Tensor, hidden_states: torch.Tensor,
topk_ids: torch.Tensor, topk_weights: torch.Tensor,
expert_map: Optional[torch.Tensor] = None, topk_ids: torch.Tensor,
global_redundant_expert_num: int = 0, expert_map: torch.Tensor | None = None,
mc2_mask: Optional[torch.Tensor] = None, global_redundant_expert_num: int = 0,
apply_router_weight_on_input: bool = False, mc2_mask: torch.Tensor | None = None,
with_quant: bool = False, apply_router_weight_on_input: bool = False,
dynamic_eplb: bool = False, with_quant: bool = False,
pertoken_scale: Optional[torch.Tensor] = None): dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
):
self.with_quant = with_quant self.with_quant = with_quant
self.hidden_shape = hidden_states.shape self.hidden_shape = hidden_states.shape
@@ -442,35 +442,32 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
dynamic_scale_after_all2all = None dynamic_scale_after_all2all = None
if self.with_quant: if self.with_quant:
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant( permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(permutated_local_input_tokens)
permutated_local_input_tokens)
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all( _, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
dynamic_scale, output_splits, input_splits, self.ep_group) dynamic_scale, output_splits, input_splits, self.ep_group
)
permute2_ep_all_to_all_handle.wait() permute2_ep_all_to_all_handle.wait()
dynamic_scale.untyped_storage().resize_(0) dynamic_scale.untyped_storage().resize_(0)
_, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all( _, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
permutated_local_input_tokens, output_splits, input_splits, permutated_local_input_tokens, output_splits, input_splits, self.ep_group
self.ep_group) )
permute1_ep_all_to_all_handle.wait() permute1_ep_all_to_all_handle.wait()
permutated_local_input_tokens.untyped_storage().resize_(0) permutated_local_input_tokens.untyped_storage().resize_(0)
# Postprocess # Postprocess
global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = self._dispatch_postprocess( global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = (
global_input_tokens, dynamic_scale_after_all2all, self._dispatch_postprocess(
global_input_tokens_local_experts_indices) global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices
)
)
context_metadata = { context_metadata = {
"input_splits": "input_splits": input_splits,
input_splits, "output_splits": output_splits,
"output_splits": "topk_weights": topk_weights,
output_splits, "reversed_local_input_permutation_mapping": reversed_local_input_permutation_mapping,
"topk_weights": "reversed_global_input_permutation_mapping": reversed_global_input_permutation_mapping,
topk_weights,
"reversed_local_input_permutation_mapping":
reversed_local_input_permutation_mapping,
"reversed_global_input_permutation_mapping":
reversed_global_input_permutation_mapping
} }
return TokenDispatchResult( return TokenDispatchResult(
@@ -485,8 +482,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher." assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
# 1. Preprocess using metadata # 1. Preprocess using metadata
hidden_states = self._combine_preprocess(hidden_states, hidden_states = self._combine_preprocess(hidden_states, context_metadata)
context_metadata)
# 2. AllToAll # 2. AllToAll
_, permutated_local_input_tokens, handle = async_all_to_all( _, permutated_local_input_tokens, handle = async_all_to_all(
@@ -499,8 +495,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
hidden_states.untyped_storage().resize_(0) hidden_states.untyped_storage().resize_(0)
# 3. Postprocess using metadata # 3. Postprocess using metadata
output = self._combine_postprocess(permutated_local_input_tokens, output = self._combine_postprocess(permutated_local_input_tokens, context_metadata)
context_metadata)
return TokenCombineResult(routed_out=output) return TokenCombineResult(routed_out=output)
@@ -534,42 +529,39 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
) )
def _preprocess(self, topk_ids: torch.Tensor): def _preprocess(self, topk_ids: torch.Tensor):
num_local_tokens_per_expert = torch.histc(topk_ids, num_local_tokens_per_expert = torch.histc(topk_ids, bins=self.num_experts, min=0, max=self.num_experts)
bins=self.num_experts,
min=0,
max=self.num_experts)
ep_size = self.ep_size ep_size = self.ep_size
self.num_out_tokens = topk_ids.numel() self.num_out_tokens = topk_ids.numel()
input_splits = (num_local_tokens_per_expert.reshape( input_splits = (
ep_size, num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
self.num_local_experts).sum(axis=1).to(torch.device("cpu"), .sum(axis=1)
non_blocking=True).numpy()) .to(torch.device("cpu"), non_blocking=True)
.numpy()
)
num_global_tokens_per_expert = gather_from_sequence_parallel_region( num_global_tokens_per_expert = gather_from_sequence_parallel_region(
num_local_tokens_per_expert, num_local_tokens_per_expert, group=self.ep_group
group=self.ep_group).reshape(ep_size, self.num_experts) ).reshape(ep_size, self.num_experts)
num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[ num_global_tokens_per_local_expert = num_global_tokens_per_expert[
0]:self.local_expert_indices[-1] + 1] :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
]
if num_global_tokens_per_local_expert is None: if num_global_tokens_per_local_expert is None:
raise ValueError( raise ValueError("num_global_tokens_per_local_expert must be set before sum.")
"num_global_tokens_per_local_expert must be set before sum.")
output_splits = (num_global_tokens_per_local_expert.sum(axis=-1).to( output_splits = (
torch.device("cpu"), non_blocking=True).numpy()) num_global_tokens_per_local_expert.sum(axis=-1).to(torch.device("cpu"), non_blocking=True).numpy()
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum( )
axis=0) num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(axis=0)
global_input_tokens_local_experts_indices = None global_input_tokens_local_experts_indices = None
if self.num_local_experts > 1: if self.num_local_experts > 1:
if num_global_tokens_per_local_expert is None: if num_global_tokens_per_local_expert is None:
raise ValueError( raise ValueError("num_global_tokens_per_local_expert must be set before operations.")
"num_global_tokens_per_local_expert must be set before operations."
)
global_input_tokens_local_experts_indices = torch.repeat_interleave( global_input_tokens_local_experts_indices = torch.repeat_interleave(
self.expert_ids_per_ep_rank, self.expert_ids_per_ep_rank, num_global_tokens_per_local_expert.ravel()
num_global_tokens_per_local_expert.ravel()) )
else: else:
torch.npu.synchronize() torch.npu.synchronize()
@@ -581,45 +573,41 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
global_input_tokens_local_experts_indices, global_input_tokens_local_experts_indices,
) )
def _dispatch_postprocess(self, global_input_tokens, def _dispatch_postprocess(
dynamic_scale_after_all2all, self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices
global_input_tokens_local_experts_indices): ):
# Early return if no local experts or no tokens # Early return if no local experts or no tokens
if self.num_local_experts <= 1: if self.num_local_experts <= 1:
return global_input_tokens, dynamic_scale_after_all2all, None return global_input_tokens, dynamic_scale_after_all2all, None
# Handle quantized case # Handle quantized case
if self.with_quant: if self.with_quant:
assert global_input_tokens_local_experts_indices is not None, \ assert global_input_tokens_local_experts_indices is not None, (
"global_input_tokens_local_experts_indices must be provided" "global_input_tokens_local_experts_indices must be provided"
)
dynamic_scale_after_all2all, _ = torch_npu.npu_moe_token_permute( dynamic_scale_after_all2all, _ = torch_npu.npu_moe_token_permute(
dynamic_scale_after_all2all.unsqueeze(-1), dynamic_scale_after_all2all.unsqueeze(-1), global_input_tokens_local_experts_indices
global_input_tokens_local_experts_indices) )
dynamic_scale_after_all2all = dynamic_scale_after_all2all.squeeze( dynamic_scale_after_all2all = dynamic_scale_after_all2all.squeeze(-1)
-1)
# Non-quantized case # Non-quantized case
global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute( global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
global_input_tokens, global_input_tokens_local_experts_indices) global_input_tokens, global_input_tokens_local_experts_indices
)
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
def _combine_preprocess(self, hidden_states: torch.Tensor, def _combine_preprocess(self, hidden_states: torch.Tensor, context_metadata: dict) -> torch.Tensor:
context_metadata: dict) -> torch.Tensor:
# Unpermutation 2: expert output to AlltoAll input # Unpermutation 2: expert output to AlltoAll input
if hidden_states.shape[0] > 0 and self.num_local_experts > 1: if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
rev_global = context_metadata[ rev_global = context_metadata["reversed_global_input_permutation_mapping"]
"reversed_global_input_permutation_mapping"] hidden_states = torch_npu.npu_moe_token_unpermute(hidden_states, rev_global)
hidden_states = torch_npu.npu_moe_token_unpermute(
hidden_states, rev_global)
return hidden_states return hidden_states
def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor, def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor, context_metadata: dict) -> torch.Tensor:
context_metadata: dict) -> torch.Tensor:
# Unpermutation 1: AlltoAll output to output # Unpermutation 1: AlltoAll output to output
output = torch_npu.npu_moe_token_unpermute( output = torch_npu.npu_moe_token_unpermute(
permuted_tokens=permutated_local_input_tokens, permuted_tokens=permutated_local_input_tokens,
sorted_indices=context_metadata[ sorted_indices=context_metadata["reversed_local_input_permutation_mapping"].to(torch.int32),
"reversed_local_input_permutation_mapping"].to(torch.int32),
probs=context_metadata["topk_weights"], probs=context_metadata["topk_weights"],
restore_shape=self.hidden_shape_before_permute, restore_shape=self.hidden_shape_before_permute,
) )