[Lint] Style: Convert vllm-ascend/ to ruff format (Batch #11) (#6176)

### What this PR does / why we need it?
Part of the batch migration of vllm-ascend to the ruff formatter and linter. Batch #11 converts the `vllm_ascend/ops/fused_moe/` files listed below and removes that path from the ruff `exclude` list. The changes are mechanical (line wrapping, import reflow, and typing-syntax modernization) with no functional changes. An illustrative sketch follows the table.

**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/fused_moe/comm_utils.py` |
| `vllm_ascend/ops/fused_moe/experts_selector.py` |
| `vllm_ascend/ops/fused_moe/fused_moe.py` |
| `vllm_ascend/ops/fused_moe/moe_comm_method.py` |
| `vllm_ascend/ops/fused_moe/moe_mlp.py` |
| `vllm_ascend/ops/fused_moe/prepare_finalize.py` |
| `vllm_ascend/ops/fused_moe/token_dispatcher.py` |
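
The rewrite itself is mechanical: hanging-indent argument lists are rewrapped into ruff's style, long import lists are reflowed, and `typing.Optional[X]` / `typing.Callable` annotations become `X | None` / `collections.abc.Callable`. A minimal before/after sketch (hypothetical, simplified signatures modeled on the hunks below, not verbatim excerpts):

```python
import torch


# Before: hanging-indent wrapping and typing.Optional
# from typing import Optional
#
# def finalize(hidden_states: torch.Tensor,
#              reduce_results: bool,
#              context_metadata: Optional[dict] = None) -> torch.Tensor:
#     ...


# After: ruff collapses the signature onto one line when it fits the configured
# line length, and Optional[X] is rewritten to the PEP 604 form X | None.
def finalize(hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None) -> torch.Tensor:
    ...


# Signatures that do not fit on one line are exploded to one argument per line
# with a trailing comma instead of hanging indents.
def select_experts(
    hidden_states: torch.Tensor,
    router_logits: torch.Tensor,
    top_k: int,
    renormalize: bool,
    topk_group: int | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    ...
```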

### Does this PR introduce _any_ user-facing change?

No. This is a formatting-only change; no functional behavior is modified.

### How was this patch tested?

- vLLM version: v0.14.0
- vLLM main: d68209402d

Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
Commit 65b7f716e6 (parent 4fb3d5e1b2), authored by SILONG ZENG on 2026-02-06 15:28:49 +08:00 and committed by GitHub.
8 changed files with 694 additions and 784 deletions.

View File: `pyproject.toml` (ruff configuration)

@@ -62,9 +62,6 @@ exclude = [
"vllm_ascend/worker/v2/**",
"vllm_ascend/worker/npu_input_batch.py",
"vllm_ascend/ops/rotary_embedding.py",
# (11)
"vllm_ascend/ops/fused_moe/**",
]
[tool.ruff.lint]
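
Note: the removal of `"vllm_ascend/ops/fused_moe/**"` from the `exclude` list above (the entry tracked by the `# (11)` batch marker) is what brings the files in this batch under ruff. The file diffs below correspond roughly to running `ruff format` and `ruff check --fix` over `vllm_ascend/ops/fused_moe/`; the exact invocation is an assumption, since the project may route this through a pre-commit hook or lint script.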

View File: `vllm_ascend/ops/fused_moe/comm_utils.py`

@@ -23,11 +23,7 @@ import torch_npu
COMM_STREAM = None
def async_all_to_all(input_,
output_split_sizes,
input_split_sizes,
group,
event=None):
def async_all_to_all(input_, output_split_sizes, input_split_sizes, group, event=None):
if output_split_sizes is None:
# Equal split (all2all)
a2a_out = torch.empty_like(input_)
@@ -43,8 +39,7 @@ def async_all_to_all(input_,
# multi stream wait event
global COMM_STREAM
if COMM_STREAM is None:
COMM_STREAM = torch_npu.npu.Stream(
device=torch.npu.current_device())
COMM_STREAM = torch_npu.npu.Stream(device=torch.npu.current_device())
with torch_npu.npu.stream(COMM_STREAM):
event.wait()
handle = dist.all_to_all_single(
@@ -53,14 +48,17 @@ def async_all_to_all(input_,
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
async_op=True)
async_op=True,
)
else:
handle = dist.all_to_all_single(a2a_out,
input_.contiguous(),
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
async_op=True)
handle = dist.all_to_all_single(
a2a_out,
input_.contiguous(),
output_split_sizes=output_split_sizes,
input_split_sizes=input_split_sizes,
group=group,
async_op=True,
)
return input_, a2a_out, handle
@@ -86,19 +84,12 @@ def _gather_along_first_dim(input_, group, output_split_sizes=None):
if output_split_sizes is None:
dim_size[0] = dim_size[0] * world_size
output = torch.empty(dim_size,
dtype=input_.dtype,
device=torch.npu.current_device())
torch.distributed.all_gather_into_tensor(output,
input_.contiguous(),
group=group)
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.npu.current_device())
torch.distributed.all_gather_into_tensor(output, input_.contiguous(), group=group)
else:
dim_size[0] = sum(output_split_sizes)
output = torch.empty(dim_size,
dtype=input_.dtype,
device=torch.npu.current_device())
output_tensor_list = list(
torch.split(output, output_split_sizes, dim=0))
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.npu.current_device())
output_tensor_list = list(torch.split(output, output_split_sizes, dim=0))
torch.distributed.all_gather(output_tensor_list, input_, group=group)
return output
@@ -110,4 +101,4 @@ def gather_from_sequence_parallel_region(
output_split_sizes=None,
):
"""Wrapper for autograd function: forward: AG, backward: RS <first dim>"""
return _gather_along_first_dim(input_, group, output_split_sizes)
return _gather_along_first_dim(input_, group, output_split_sizes)

View File: `vllm_ascend/ops/fused_moe/experts_selector.py`

@@ -14,26 +14,28 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Callable, Optional
from collections.abc import Callable
import torch
from vllm_ascend.utils import get_weight_prefetch_method
def select_experts(hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor=1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
indices_type: Optional[torch.dtype] = None,
global_num_experts: int = -1):
def select_experts(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor=1.0,
e_score_correction_bias: torch.Tensor | None = None,
indices_type: torch.dtype | None = None,
global_num_experts: int = -1,
):
"""
Fused experts with select experts.
@@ -58,8 +60,7 @@ def select_experts(hidden_states: torch.Tensor,
# prefetch w1_w3_proj.weight preprocess
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
hidden_states, "gate_up")
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
hidden_states=hidden_states,
top_k=top_k,
@@ -67,7 +68,8 @@ def select_experts(hidden_states: torch.Tensor,
topk_group=topk_group,
num_expert_group=num_expert_group,
scoring_func=scoring_func,
custom_routing_function=custom_routing_function)
custom_routing_function=custom_routing_function,
)
if is_support_npu_moe_gating_top_k:
topk_weights, topk_ids = _select_experts_with_fusion_ops(
@@ -81,7 +83,8 @@ def select_experts(hidden_states: torch.Tensor,
num_expert_group=num_expert_group,
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
global_num_experts=global_num_experts)
global_num_experts=global_num_experts,
)
else:
topk_weights, topk_ids = _native_select_experts(
hidden_states=hidden_states,
@@ -100,14 +103,15 @@ def select_experts(hidden_states: torch.Tensor,
def check_npu_moe_gating_top_k(
hidden_states: torch.Tensor,
top_k: int,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
scoring_func: str = "softmax",
custom_routing_function: Optional[Callable] = None):
if scoring_func == "sigmoid" and not renormalize: #sigmoid + renorm=0 is not supported in current branch
hidden_states: torch.Tensor,
top_k: int,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
scoring_func: str = "softmax",
custom_routing_function: Callable | None = None,
):
if scoring_func == "sigmoid" and not renormalize: # sigmoid + renorm=0 is not supported in current branch
return False
if custom_routing_function is not None:
return False
@@ -115,39 +119,39 @@ def check_npu_moe_gating_top_k(
return False
topk_group = topk_group if topk_group is not None else 1
num_expert_group = num_expert_group if num_expert_group is not None else 1
if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
== 0 and hidden_states.shape[-1] // num_expert_group > 2):
if not (
num_expert_group > 0
and hidden_states.shape[-1] % num_expert_group == 0
and hidden_states.shape[-1] // num_expert_group > 2
):
return False
if topk_group < 1 or topk_group > num_expert_group:
return False
if top_k < 1 or \
top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
if top_k < 1 or top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
return False
if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
if topk_group * hidden_states.shape[-1] / num_expert_group < top_k: # noqa: SIM103
return False
return True
def _native_grouped_topk(
topk_weights: torch.Tensor,
num_expert_group: Optional[int],
topk_group: Optional[int],
num_expert_group: int | None,
topk_group: int | None,
):
topk_group = 0 if topk_group is None else topk_group
num_expert_group = 0 if num_expert_group is None else num_expert_group
num_token = topk_weights.shape[0]
grouped_weights = topk_weights.view(num_token, num_expert_group,
-1).max(dim=-1).values
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
k=topk_group,
dim=-1,
sorted=False)[1]
grouped_weights = topk_weights.view(num_token, num_expert_group, -1).max(dim=-1).values
topk_group_indices = torch.topk(grouped_weights.to(torch.float32), k=topk_group, dim=-1, sorted=False)[1]
topk_group_mask = torch.zeros_like(grouped_weights)
topk_group_mask.scatter_(1, topk_group_indices, 1)
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
num_token, num_expert_group,
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
topk_weight_mask = (
topk_group_mask.unsqueeze(-1)
.expand(num_token, num_expert_group, topk_weights.shape[-1] // num_expert_group)
.reshape(num_token, -1)
)
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
return topk_weights
@@ -163,9 +167,13 @@ def _renormalize_topk_weights(
def _select_expert_use_group_topk(
topk_weights: torch.Tensor, topk_group: Optional[int],
renormalize: bool, top_k: int, num_expert_group: Optional[int],
e_score_correction_bias: Optional[torch.Tensor]):
topk_weights: torch.Tensor,
topk_group: int | None,
renormalize: bool,
top_k: int,
num_expert_group: int | None,
e_score_correction_bias: torch.Tensor | None,
):
assert topk_group is not None
assert num_expert_group is not None
@@ -177,47 +185,38 @@ def _select_expert_use_group_topk(
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
topk_weights = _native_grouped_topk(topk_weights, num_expert_group,
topk_group)
topk_weights = _native_grouped_topk(topk_weights, num_expert_group, topk_group)
# TODO bfloat16 is not supported in torch.topk with ge graph.
if e_score_correction_bias is not None:
topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)[1]
topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)[1]
# Use original unbiased scores for the routing weights
topk_weights = original_weights.gather(1, topk_ids)
else:
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
k=top_k,
dim=-1,
sorted=False)
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)
topk_ids = topk_ids.to(torch.int32)
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
return topk_weights, topk_ids
def _select_experts_with_fusion_ops(
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
e_score_correction_bias: Optional[torch.Tensor],
topk_group: Optional[int],
num_expert_group: Optional[int],
scoring_func: str = "softmax",
routed_scaling_factor=1.0,
global_num_experts: int = -1):
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
e_score_correction_bias: torch.Tensor | None,
topk_group: int | None,
num_expert_group: int | None,
scoring_func: str = "softmax",
routed_scaling_factor=1.0,
global_num_experts: int = -1,
):
topk_group = topk_group if topk_group is not None else 1
num_expert_group = num_expert_group if num_expert_group is not None else 1
renorm = int(renormalize)
norm_type = 0 if scoring_func == "softmax" else 1
if e_score_correction_bias is not None and \
e_score_correction_bias.dtype != router_logits.dtype:
e_score_correction_bias = e_score_correction_bias.to(
router_logits.dtype)
if e_score_correction_bias is not None and e_score_correction_bias.dtype != router_logits.dtype:
e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k(
router_logits,
k=top_k,
@@ -228,7 +227,7 @@ def _select_experts_with_fusion_ops(
norm_type=norm_type, # 0: softmax; 1: sigmoid
out_flag=False,
routed_scaling_factor=routed_scaling_factor,
eps=float(1e-20),
eps=1e-20,
bias_opt=e_score_correction_bias,
)
@@ -241,12 +240,12 @@ def _native_select_experts(
top_k: int,
use_grouped_topk: bool,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: Optional[torch.Tensor] = None
e_score_correction_bias: torch.Tensor | None = None,
global_num_experts: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Select top-k experts based on router logits.
@@ -285,7 +284,8 @@ def _native_select_experts(
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
)
if custom_routing_function is not None:
topk_weights, topk_ids = custom_routing_function(
@@ -293,7 +293,8 @@ def _native_select_experts(
gating_output=router_logits,
topk=top_k,
renormalize=renormalize,
global_num_experts=global_num_experts)
global_num_experts=global_num_experts,
)
# Required by npu_moe_init_routing
topk_ids = topk_ids.to(torch.int32)
return topk_weights, topk_ids
@@ -318,8 +319,7 @@ def zero_experts_compute(
if zero_expert_type == "identity":
zero_expert_mask = expert_indices < num_experts
zero_expert_scales = expert_scales.clone()
zero_expert_scales = torch.where(zero_expert_mask, 0.0,
zero_expert_scales)
zero_expert_scales = torch.where(zero_expert_mask, 0.0, zero_expert_scales)
hidden_states = hidden_states.unsqueeze(1)
zero_expert_scales = zero_expert_scales.unsqueeze(2)

View File: `vllm_ascend/ops/fused_moe/fused_moe.py`

@@ -14,40 +14,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import wraps
from typing import Callable, Optional
import torch
import torch.nn.functional as F
from vllm.config import get_current_vllm_config
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
tensor_model_parallel_all_reduce)
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group, tensor_model_parallel_all_reduce
from vllm.forward_context import get_forward_context
from vllm.logger import logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
from vllm.model_executor.layers.fused_moe.layer import (
FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map)
from vllm.model_executor.layers.fused_moe.shared_fused_moe import \
SharedFusedMoE
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
from vllm_ascend.flash_common3_context import (get_flash_common3_context,
set_flash_common3_context)
from vllm_ascend.ops.fused_moe.experts_selector import (select_experts,
zero_experts_compute)
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
FusedExpertsResult,
setup_moe_comm_method)
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
from vllm_ascend.utils import (AscendDeviceType, enable_sp,
get_ascend_device_type, maybe_trans_nz,
npu_stream_switch, shared_expert_dp_enabled,
shared_experts_calculation_stream,
vllm_version_is)
from vllm_ascend.utils import (
enable_sp,
maybe_trans_nz,
npu_stream_switch,
shared_expert_dp_enabled,
shared_experts_calculation_stream,
vllm_version_is,
)
@dataclass
class FusedMoEResult:
@@ -64,46 +61,43 @@ class FusedMoEEvents:
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
def __init__(self, moe: FusedMoEConfig = None):
super().__init__(moe=moe)
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
def process_weights_after_loading(self, layer):
super(UnquantizedFusedMoEMethod,
self).process_weights_after_loading(layer)
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
1, 2).contiguous()
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(1, 2).contiguous()
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
1, 2).contiguous()
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(1, 2).contiguous()
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: Optional[torch.Tensor] = None,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None,
**kwargs) -> torch.Tensor:
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
use_grouped_topk: bool,
top_k: int,
router_logits: torch.Tensor,
renormalize: bool,
topk_group: int | None = None,
num_expert_group: int | None = None,
custom_routing_function: Callable | None = None,
scoring_func: str = "softmax",
routed_scaling_factor: float = 1.0,
e_score_correction_bias: torch.Tensor | None = None,
global_num_experts: int = -1,
expert_map: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
enable_force_load_balance: bool = False,
log2phy: torch.Tensor = None,
**kwargs,
) -> torch.Tensor:
zero_expert_num = getattr(layer, "zero_expert_num", 0)
zero_expert_type = getattr(layer, "zero_expert_type", None)
topk_weights, topk_ids = select_experts(
@@ -118,7 +112,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
e_score_correction_bias=e_score_correction_bias,
global_num_experts=global_num_experts)
global_num_experts=global_num_experts,
)
if zero_expert_num > 0 and zero_expert_type is not None:
topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
@@ -134,11 +129,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
# to avoid accumulating too much tokens on a single rank.
# currently it is only activated when doing profile runs.
if enable_force_load_balance:
random_matrix = torch.rand(topk_ids.size(0),
global_num_experts,
device=topk_ids.device)
topk_ids = torch.argsort(
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device)
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
moe_comm_method = get_forward_context().moe_comm_method
final_hidden_states = moe_comm_method.fused_experts(
@@ -151,7 +143,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
apply_router_weight_on_input=apply_router_weight_on_input,
dynamic_eplb=self.dynamic_eplb,
log2phy=log2phy,
mc2_mask=kwargs.get("mc2_mask", None))
mc2_mask=kwargs.get("mc2_mask"),
)
if zero_expert_num > 0 and zero_expert_type is not None:
final_hidden_states += zero_expert_result
return final_hidden_states
@@ -159,7 +152,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
class AscendFusedMoE(FusedMoE):
moe_counter = -1
gate_stream: Optional[torch.npu.Stream] = None
gate_stream: torch.npu.Stream | None = None
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@@ -174,11 +167,9 @@ class AscendFusedMoE(FusedMoE):
self.log2phy = None
if self.quant_config is None:
self.quant_method = AscendUnquantizedFusedMoEMethod(
self.moe_config)
self.quant_method = AscendUnquantizedFusedMoEMethod(self.moe_config)
else:
self.quant_method = self.quant_config.get_quant_method(
self, self.layer_name)
self.quant_method = self.quant_config.get_quant_method(self, self.layer_name)
assert self.quant_method is not None
@@ -195,28 +186,32 @@ class AscendFusedMoE(FusedMoE):
if self.custom_routing_function is None and self.e_score_correction_bias is not None:
vllm_config = get_current_vllm_config()
self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(
dtype=vllm_config.model_config.dtype)
dtype=vllm_config.model_config.dtype
)
# init moe
eplb_config = ascend_config.eplb_config
self.global_expert_map, self._expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config(
eplb_config, self.moe_instance_id, self.moe_config)
eplb_config, self.moe_instance_id, self.moe_config
)
self.global_num_experts = num_experts + self.global_redundant_expert_num
self.dynamic_eplb = eplb_config.dynamic_eplb and (self.log2phy
is not None)
self.local_num_experts = (torch.sum(
self._expert_map != -1).item() if self._expert_map is not None else
self.global_num_experts)
self.dynamic_eplb = eplb_config.dynamic_eplb and (self.log2phy is not None)
self.local_num_experts = (
torch.sum(self._expert_map != -1).item() if self._expert_map is not None else self.global_num_experts
)
if self._expert_map is not None:
logger.info_once(
"[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
" number of experts: %s/%s. Experts local to global index map:"
" %s.", self.ep_rank, self.ep_size, self.local_num_experts,
" %s.",
self.ep_rank,
self.ep_size,
self.local_num_experts,
self.global_num_experts,
get_compressed_expert_map(self._expert_map))
get_compressed_expert_map(self._expert_map),
)
if self.dynamic_eplb:
self.moe_load = torch.zeros(self.local_num_experts,
dtype=torch.int64).npu()
self.moe_load = torch.zeros(self.local_num_experts, dtype=torch.int64).npu()
self.moe_config.num_experts = self.global_num_experts
self.moe_config.num_local_experts = self.local_num_experts
@@ -225,14 +220,12 @@ class AscendFusedMoE(FusedMoE):
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": self.hidden_size,
"intermediate_size_per_partition":
self.intermediate_size_per_partition,
"intermediate_size_per_partition": self.intermediate_size_per_partition,
"params_dtype": self.params_dtype,
"weight_loader": self.weight_loader,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
if self.quant_method.__class__.__name__ in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod"):
moe_quant_params["intermediate_size_full"] = intermediate_size
self.quant_method.create_weights(layer=self, **moe_quant_params)
@@ -243,15 +236,14 @@ class AscendFusedMoE(FusedMoE):
def _get_quant_type(self) -> QuantType:
quant_method = self.quant_method
if not hasattr(quant_method,
"quant_method") or quant_method.quant_method is None:
if not hasattr(quant_method, "quant_method") or quant_method.quant_method is None:
return QuantType.NONE
method = quant_method.quant_method
if hasattr(method, "quant_type"):
from vllm_ascend.quantization.methods.base import \
QuantType as SchemeQuantType
from vllm_ascend.quantization.methods.base import QuantType as SchemeQuantType
scheme_quant_type = method.quant_type
if scheme_quant_type == SchemeQuantType.W8A8:
return QuantType.W8A8
@@ -270,22 +262,18 @@ class AscendFusedMoE(FusedMoE):
if self.moe_load is not None:
self.moe_load.zero_()
def maybe_all_reduce_tensor_model_parallel(
self, final_hidden_states: torch.Tensor):
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
"""NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`,
and `alltoallcommimpl`, we do not need to all-reduce the final outputs since
the outputs are already aggregated across tensor parallel ranks in the
`finalize` function. In `allgathercommimpl`, we still need to all-reduce the
outputs since each rank only has partial outputs.
"""
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
final_hidden_states)
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
def forward_impl( # type: ignore[override]
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
return_with_event: bool = False) -> torch.Tensor | FusedMoEResult:
self, hidden_states: torch.Tensor, router_logits: torch.Tensor, return_with_event: bool = False
) -> torch.Tensor | FusedMoEResult:
assert self.quant_method is not None
forward_context = get_forward_context()
@@ -301,15 +289,16 @@ class AscendFusedMoE(FusedMoE):
fc3_context = get_flash_common3_context()
assert fc3_context is not None
AscendFusedMoE.gate_stream.wait_stream(torch.npu.current_stream())
with npu_stream_switch(AscendFusedMoE.gate_stream,
enabled=self.multistream_overlap_gate):
with npu_stream_switch(AscendFusedMoE.gate_stream, enabled=self.multistream_overlap_gate):
# share_expert
assert fc3_context.shared_experts is not None
shared_out = fc3_context.shared_experts(hidden_states)
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled():
if (
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
and not shared_expert_dp_enabled()
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
set_flash_common3_context(shared_out=shared_out)
@@ -325,24 +314,22 @@ class AscendFusedMoE(FusedMoE):
scoring_func=self.scoring_func,
routed_scaling_factor=self.routed_scaling_factor,
e_score_correction_bias=self.e_score_correction_bias,
global_num_experts=self.global_num_experts)
global_num_experts=self.global_num_experts,
)
if isinstance(forward_context.moe_comm_method,
AllGatherCommImpl):
topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
topk_weights, True, True)
topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
topk_ids, True, True)
if isinstance(forward_context.moe_comm_method, AllGatherCommImpl):
topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_weights, True, True)
topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_ids, True, True)
set_flash_common3_context(topk_weights=topk_weights,
topk_ids=topk_ids)
set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
hidden_states=hidden_states,
router_logits=router_logits,
replace_allreduce=forward_context.sp_enabled,
enable_shared_expert_dp=self.enable_shared_expert_dp,
quant_type=self.quant_type)
quant_type=self.quant_type,
)
# Make sure the default stream waits for the gate stream to finish.
if self.multistream_overlap_gate:
@@ -375,39 +362,45 @@ class AscendFusedMoE(FusedMoE):
enable_force_load_balance=enable_force_load_balance,
log2phy=self.log2phy,
global_redundant_expert_num=self.global_redundant_expert_num,
mc2_mask=mc2_mask)
mc2_mask=mc2_mask,
)
if self.dynamic_eplb:
expert_tokens = fused_experts_results.expert_tokens
group_list_type = fused_experts_results.group_list_type
assert expert_tokens is not None and group_list_type is not None, \
assert expert_tokens is not None and group_list_type is not None, (
"expert_tokens and group_list_type should not be None when dynamic_eplb is enabled."
local_load = expert_tokens if group_list_type == 1 else \
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
)
local_load = (
expert_tokens
if group_list_type == 1
else torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
)
self.moe_load.add_(local_load)
routed_out = forward_context.moe_comm_method.finalize(
hidden_states=fused_experts_results.routed_out,
reduce_results=self.reduce_results,
context_metadata=context_metadata)
context_metadata=context_metadata,
)
if return_with_event:
return FusedMoEResult(
routed_out=routed_out,
before_dispatch_evt=fused_experts_results.before_dispatch_evt,
before_combine_evt=fused_experts_results.before_combine_evt)
before_combine_evt=fused_experts_results.before_combine_evt,
)
else:
# The vLLM FusedMoE forward_impl does not return events.
return routed_out
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
def __init__(
self,
shared_experts: torch.nn.Module,
gate: Optional[torch.nn.Module] = None,
gate: torch.nn.Module | None = None,
use_overlapped: bool = True,
routed_input_transform: Optional[torch.nn.Module] = None,
routed_input_transform: torch.nn.Module | None = None,
**kwargs,
):
AscendFusedMoE.__init__(self, **kwargs)
@@ -418,16 +411,12 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
self.use_overlapped = use_overlapped
self.shared_expert_stream = None
ascend_config = get_ascend_config()
self.multistream_overlap_shared_expert = \
ascend_config.multistream_overlap_shared_expert and \
self._shared_experts is not None
self.multistream_overlap_gate = \
ascend_config.multistream_overlap_gate and \
self._shared_experts is not None
self.multistream_overlap_shared_expert = (
ascend_config.multistream_overlap_shared_expert and self._shared_experts is not None
)
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate and self._shared_experts is not None
if enable_sp():
logger.info_once(
"Sequence parallelism is enabled, shared experts are replicated for best performance."
)
logger.info_once("Sequence parallelism is enabled, shared experts are replicated for best performance.")
self._gate = gate
@@ -447,20 +436,15 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore
def _shared_experts_part1(self, hidden_states: torch.Tensor):
shared_gate_up, _ = self._shared_experts.gate_up_proj(
hidden_states) # type: ignore
shared_gate_up, _ = self._shared_experts.gate_up_proj(hidden_states) # type: ignore
return shared_gate_up
def _shared_experts_part2(self, hidden_states: torch.Tensor,
shared_gate_up: torch.Tensor):
shared_act = self._shared_experts.act_fn(
shared_gate_up) # type: ignore
shared_out, _ = self._shared_experts.down_proj(
shared_act) # type: ignore
def _shared_experts_part2(self, hidden_states: torch.Tensor, shared_gate_up: torch.Tensor):
shared_act = self._shared_experts.act_fn(shared_gate_up) # type: ignore
shared_out, _ = self._shared_experts.down_proj(shared_act) # type: ignore
# Qwen3-Next specific gating mechanism
if hasattr(self._shared_experts, "expert_gate") and \
self._shared_experts.expert_gate is not None:
if hasattr(self._shared_experts, "expert_gate") and self._shared_experts.expert_gate is not None:
gate_out, _ = self._shared_experts.expert_gate(hidden_states) # type: ignore
shared_out = F.sigmoid(gate_out) * shared_out
return shared_out
@@ -468,9 +452,9 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
def _validate_shared_expert_consistency(self):
"""Validate that split shared expert computation matches integrated
computation."""
test_input = torch.rand(
10, self.hidden_size, device='npu', dtype=self.moe_config.in_dtype
) * 2 - 1 # Random input for testing, scoped to [-1, 1]
test_input = (
torch.rand(10, self.hidden_size, device="npu", dtype=self.moe_config.in_dtype) * 2 - 1
) # Random input for testing, scoped to [-1, 1]
integrated_out = self._shared_experts(test_input)
part1_out = self._shared_experts_part1(test_input)
@@ -478,25 +462,19 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
if not torch.allclose(integrated_out, split_out):
diff = (integrated_out - split_out).abs()
logger.error(
"SharedFusedMoE shared experts split computation does not "
"match the integrated computation.")
logger.error("SharedFusedMoE shared experts split computation does not match the integrated computation.")
logger.error(f"Max absolute difference: {diff.max().item()}")
logger.error("Integrated output - sum: %s, norm: %s",
integrated_out.sum().item(),
integrated_out.norm().item())
logger.error("Split output - sum: %s, norm: %s",
split_out.sum().item(),
split_out.norm().item())
logger.error(
"Integrated output - sum: %s, norm: %s", integrated_out.sum().item(), integrated_out.norm().item()
)
logger.error("Split output - sum: %s, norm: %s", split_out.sum().item(), split_out.norm().item())
raise ValueError(
"SharedFusedMoE shared experts split computation does not "
"match the integrated computation.")
logger.info_once(
"SharedFusedMoE shared experts split computation matches the "
"integrated computation.")
"SharedFusedMoE shared experts split computation does not match the integrated computation."
)
logger.info_once("SharedFusedMoE shared experts split computation matches the integrated computation.")
@property
def gate(self) -> Optional[torch.nn.Module]:
def gate(self) -> torch.nn.Module | None:
return self._gate if self.use_overlapped else None
@property
@@ -530,8 +508,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
)
return shared_out, fused_out
def _forward_shared_experts(self, hidden_states: torch.Tensor,
fused_moe_evts: FusedMoEEvents):
def _forward_shared_experts(self, hidden_states: torch.Tensor, fused_moe_evts: FusedMoEEvents):
if self._shared_experts is None:
return None
@@ -539,11 +516,9 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
if evt is not None:
torch.npu.current_stream().wait_event(evt)
with npu_stream_switch(shared_experts_calculation_stream(),
enabled=self.multistream_overlap_shared_expert):
with npu_stream_switch(shared_experts_calculation_stream(), enabled=self.multistream_overlap_shared_expert):
# Ensure the shared experts wait for hidden_states to be ready.
torch.npu.current_stream().wait_event(
fused_moe_evts.before_routed_experts)
torch.npu.current_stream().wait_event(fused_moe_evts.before_routed_experts)
# Execute the gate projection and activation concurrently with the
# dispatch communication.
maybe_wait_event(fused_moe_evts.before_dispatch)
@@ -556,20 +531,22 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
# Make sure the default stream waits for the shared experts stream to
# finish.
if self.multistream_overlap_shared_expert:
torch.npu.current_stream().wait_stream(
shared_experts_calculation_stream())
torch.npu.current_stream().wait_stream(shared_experts_calculation_stream())
# NOTE: This is exactly the opposite of
# `maybe_all_reduce_tensor_model_parallel`
forward_context = get_forward_context()
moe_comm_type = forward_context.moe_comm_type
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
and not shared_expert_dp_enabled():
if (
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
and not shared_expert_dp_enabled()
):
shared_out = tensor_model_parallel_all_reduce(shared_out)
return shared_out
def forward_impl( # type: ignore[override]
self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
):
if self.multistream_overlap_gate:
set_flash_common3_context(shared_experts=self._shared_experts)
@@ -596,6 +573,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
before_routed_experts=before_routed_experts,
before_dispatch=fused_moe_results.before_dispatch_evt,
before_combine=fused_moe_results.before_combine_evt,
))
),
)
return shared_out, routed_out

View File: `vllm_ascend/ops/fused_moe/moe_comm_method.py`

@@ -17,7 +17,6 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Dict, Optional
import torch
from vllm.forward_context import get_forward_context
@@ -27,18 +26,24 @@ import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
from vllm_ascend.ops.fused_moe.prepare_finalize import (
PrepareAndFinalize, PrepareAndFinalizeWithAll2All,
PrepareAndFinalizeWithAllGather, PrepareAndFinalizeWithMC2, QuantType)
PrepareAndFinalize,
PrepareAndFinalizeWithAll2All,
PrepareAndFinalizeWithAllGather,
PrepareAndFinalizeWithMC2,
QuantType,
)
from vllm_ascend.ops.fused_moe.token_dispatcher import (
MoETokenDispatcher, TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
MoETokenDispatcher,
TokenDispatcherWithAll2AllV,
TokenDispatcherWithAllGather,
TokenDispatcherWithMC2,
)
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
_MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}
def get_moe_comm_method(
moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
return _MoECommMethods.get(moe_comm_type, None)
def get_moe_comm_method(moe_comm_type: MoECommType | None) -> MoECommMethod | None:
return _MoECommMethods.get(moe_comm_type)
def setup_moe_comm_method(moe_config):
@@ -50,6 +55,7 @@ def setup_moe_comm_method(moe_config):
def set_gmmswigluquant_method():
from vllm_ascend.ascend_config import get_ascend_config
ascend_config = get_ascend_config()
return ascend_config.ascend_fusion_config.fusion_ops_gmmswigluquant
@@ -84,51 +90,46 @@ class MoECommMethod(ABC):
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type: QuantType = QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
hidden_states, router_logits, enable_shared_expert_dp,
replace_allreduce, quant_type)
hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce, quant_type
)
return hidden_states, router_logits, mc2_mask, context_metadata
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
hidden_states = self.prepare_finalize.finalize(hidden_states,
reduce_results,
context_metadata)
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
) -> torch.Tensor:
hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, context_metadata)
return hidden_states
def fused_experts(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None,
pertoken_scale: Optional[torch.Tensor] = None):
self,
hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False,
expert_map: torch.Tensor | None = None,
w1_scale: list[torch.Tensor] | None = None,
w2_scale: list[torch.Tensor] | None = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: torch.Tensor | None = None,
w2_offset: torch.Tensor | None = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None,
pertoken_scale: torch.Tensor | None = None,
):
# Check constraints
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16, torch.int8
]
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
moe_comm_method = get_forward_context().moe_comm_method
assert moe_comm_method is not None, "Missing communication context"
@@ -143,13 +144,13 @@ class MoECommMethod(ABC):
topk_weights=topk_weights,
topk_ids=topk_ids,
expert_map=expert_map,
global_redundant_expert_num=self.moe_config.
global_redundant_expert_num,
global_redundant_expert_num=self.moe_config.global_redundant_expert_num,
mc2_mask=mc2_mask,
apply_router_weight_on_input=apply_router_weight_on_input,
with_quant=use_int8_w8a8 or use_int4_w4a8,
dynamic_eplb=dynamic_eplb,
pertoken_scale=pertoken_scale)
pertoken_scale=pertoken_scale,
)
mlp_output = unified_apply_mlp(
hidden_states=dispatch_results.hidden_states,
@@ -168,29 +169,29 @@ class MoECommMethod(ABC):
with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16,
fusion=use_int8_w8a8 and self.use_fusion_ops,
need_trans=need_trans,
dynamic_eplb=dynamic_eplb)
dynamic_eplb=dynamic_eplb,
)
before_combine_evt = torch.npu.current_stream().record_event()
combine_results = self.token_dispatcher.token_combine(
hidden_states=mlp_output,
context_metadata=dispatch_results.context_metadata)
hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata
)
return FusedExpertsResult(
routed_out=combine_results.routed_out,
before_dispatch_evt=before_dispatch_evt,
before_combine_evt=before_combine_evt,
group_list_type=dispatch_results.group_list_type,
expert_tokens=dispatch_results.group_list)
expert_tokens=dispatch_results.group_list,
)
@abstractmethod
def _get_token_dispatcher(self) -> MoETokenDispatcher:
raise NotImplementedError(
"_get_token_dispatcher function not implemented.")
raise NotImplementedError("_get_token_dispatcher function not implemented.")
@abstractmethod
def _get_prepare_finalize(self) -> PrepareAndFinalize:
raise NotImplementedError(
"_get_prepare_finalize function not implemented.")
raise NotImplementedError("_get_prepare_finalize function not implemented.")
class AllGatherCommImpl(MoECommMethod):
@@ -216,7 +217,8 @@ class AllGatherCommImpl(MoECommMethod):
return TokenDispatcherWithAllGather(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
num_local_experts=self.moe_config.num_local_experts,
)
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithAllGather(self.moe_config)
@@ -227,7 +229,7 @@ class MC2CommImpl(MoECommMethod):
1. `enable_expert_parallel=True`.
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
3. `enable_expert_parallel=False` is not supported.
This implementation uses the MC2 communication method, which is optimized for
Communication and Computation parallelism on Ascend devices.
"""
@@ -253,7 +255,8 @@ class AlltoAllCommImpl(MoECommMethod):
return TokenDispatcherWithAll2AllV(
top_k=self.moe_config.experts_per_token,
num_experts=self.moe_config.num_experts,
num_local_experts=self.moe_config.num_local_experts)
num_local_experts=self.moe_config.num_local_experts,
)
def _get_prepare_finalize(self):
return PrepareAndFinalizeWithAll2All(self.moe_config)
@@ -264,7 +267,7 @@ class FusedMC2CommImpl(MoECommMethod):
1. `enable_expert_parallel=True`.
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
3. `enable_expert_parallel=False` is not supported.
This implementation uses the MC2 communication method, which is optimized for
Communication and Computation parallelism on Ascend devices.
"""
@@ -276,36 +279,36 @@ class FusedMC2CommImpl(MoECommMethod):
return PrepareAndFinalizeWithMC2(self.moe_config)
def fused_experts(
self,
hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None,
pertoken_scale: Optional[torch.Tensor] = None):
assert not (
w1_scale is None or w2_scale is None
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
self,
hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_int8_w8a8: bool = False,
use_int4_w4a8: bool = False,
use_int4_w4a16: bool = False,
expert_map: torch.Tensor | None = None,
w1_scale: list[torch.Tensor] | None = None,
w2_scale: list[torch.Tensor] | None = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: torch.Tensor | None = None,
w2_offset: torch.Tensor | None = None,
# For load balance
log2phy: torch.Tensor = None,
need_trans: bool = False,
dynamic_eplb: bool = False,
mc2_mask: torch.Tensor = None,
pertoken_scale: torch.Tensor | None = None,
):
assert not (w1_scale is None or w2_scale is None), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), (
"token_dispatcher must be an instance of TokenDispatcherWithMC2."
)
# Apply log2phy if needed
if log2phy is not None:
@@ -346,10 +349,8 @@ class FusedMC2CommImpl(MoECommMethod):
ep_rank_size=self.token_dispatcher.ep_world_size,
ep_rank_id=self.token_dispatcher.ep_rank_id,
moe_expert_num=self.moe_config.num_experts,
global_bs=self.token_dispatcher.global_bs)
global_bs=self.token_dispatcher.global_bs,
)
else:
raise ValueError(
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
return FusedExpertsResult(routed_out=out,
group_list_type=group_list_type,
expert_tokens=expert_tokens)
raise ValueError(f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
return FusedExpertsResult(routed_out=out, group_list_type=group_list_type, expert_tokens=expert_tokens)

View File: `vllm_ascend/ops/fused_moe/moe_mlp.py`

@@ -14,7 +14,6 @@
# limitations under the License.
# This file is a part of the vllm-ascend project.
from typing import Optional
import torch
import torch_npu
@@ -23,24 +22,22 @@ from vllm.forward_context import get_forward_context
from vllm.triton_utils import HAS_TRITON
from vllm_ascend.ascend_forward_context import MoECommType
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
enable_custom_op, get_ascend_device_type,
get_weight_prefetch_method)
from vllm_ascend.utils import (
dispose_tensor,
enable_custom_op,
get_weight_prefetch_method,
)
def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
return fusion and dynamic_eplb and enable_custom_op()
def cumsum_group_list(group_list: torch.Tensor,
src_list_type: int,
dst_list_type: int,
active_num: int = 0,
expert_num: int = 0) -> torch.Tensor:
def cumsum_group_list(
group_list: torch.Tensor, src_list_type: int, dst_list_type: int, active_num: int = 0, expert_num: int = 0
) -> torch.Tensor:
if src_list_type not in [0, 1, 2]:
raise ValueError(
f"group_list_type should be in [0, 1, 2], but received {src_list_type}"
)
raise ValueError(f"group_list_type should be in [0, 1, 2], but received {src_list_type}")
if src_list_type == dst_list_type:
return group_list
@@ -53,10 +50,9 @@ def cumsum_group_list(group_list: torch.Tensor,
if src_list_type == 2 and dst_list_type == 0:
experts = pad(group_list[:, 0], (1, 0))
tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
cumsum_group_list = torch.full(size=(expert_num, ),
fill_value=active_num,
dtype=group_list.dtype,
device=group_list.device)
cumsum_group_list = torch.full(
size=(expert_num,), fill_value=active_num, dtype=group_list.dtype, device=group_list.device
)
for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
if end > start:
@@ -65,30 +61,32 @@ def cumsum_group_list(group_list: torch.Tensor,
return cumsum_group_list
raise NotImplementedError(
f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. "
"This feature is under development.")
"This feature is under development."
)
def quant_apply_mlp(hidden_states: torch.Tensor,
w1: list[torch.Tensor],
w1_scale: list[torch.Tensor],
w2: list[torch.Tensor],
w2_scale: list[torch.Tensor],
group_list: torch.Tensor,
group_list_type: int = 1,
dynamic_scale: torch.Tensor = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
fusion: bool = False,
dynamic_eplb: bool = False) -> torch.Tensor:
def quant_apply_mlp(
hidden_states: torch.Tensor,
w1: list[torch.Tensor],
w1_scale: list[torch.Tensor],
w2: list[torch.Tensor],
w2_scale: list[torch.Tensor],
group_list: torch.Tensor,
group_list_type: int = 1,
dynamic_scale: torch.Tensor = None,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: torch.Tensor | None = None,
w2_offset: torch.Tensor | None = None,
fusion: bool = False,
dynamic_eplb: bool = False,
) -> torch.Tensor:
if w1_offset is not None:
unquantized_hidden_states = hidden_states
quantized_hidden_states = None
elif dynamic_scale is None:
unquantized_hidden_states = hidden_states
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
# Dispose the original unquantized hidden states
# to save npu memory because they're no longer used.
dispose_tensor(unquantized_hidden_states)
@@ -103,22 +101,18 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
weight_prefetch_method = get_weight_prefetch_method()
if weight_prefetch_method:
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
hidden_states)
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
if w1_scale_bias is None and w1_offset is None and is_mc2:
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = (
torch.ops._C_ascend.
grouped_matmul_swiglu_quant_weight_nz_tensor_list(
x=hidden_states,
weight=w1,
weight_scale=w1_scale,
x_scale=pertoken_scale,
group_list=cumsum_group_list(group_list, group_list_type,
0),
))
hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
x=hidden_states,
weight=w1,
weight_scale=w1_scale,
x_scale=pertoken_scale,
group_list=cumsum_group_list(group_list, group_list_type, 0),
)
elif fusion and not dynamic_eplb:
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
@@ -126,7 +120,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
weight=w1[0],
group_list=cumsum_group_list(group_list, group_list_type, 0),
weight_scale=w1_scale[0],
x_scale=pertoken_scale)
x_scale=pertoken_scale,
)
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
else:
@@ -140,7 +135,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=torch.int32)[0]
output_dtype=torch.int32,
)[0]
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
# act_fn: swiglu
@@ -165,7 +161,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=w2_scale[0].dtype)[0]
output_dtype=w2_scale[0].dtype,
)[0]
elif w1_offset is not None:
# gmm1: gate_up_proj
hidden_states = torch_npu.npu_grouped_matmul(
@@ -177,7 +174,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
output_dtype=_output_dtype,
)[0]
dispose_tensor(unquantized_hidden_states)
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
@@ -191,13 +189,12 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
output_dtype=_output_dtype,
)[0]
else:
if w1_scale_bias is not None:
if group_list_type == 0:
group_list = torch.cat(
[group_list[:1],
torch.diff(group_list, dim=0)])
group_list = torch.cat([group_list[:1], torch.diff(group_list, dim=0)])
group_list_type = 1
bias1 = [w1_scale_bias] if not fusion else w1_scale_bias
bias2 = [w2_scale_bias]
@@ -206,17 +203,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = (
torch.ops._C_ascend.
grouped_matmul_swiglu_quant_weight_nz_tensor_list(
x=hidden_states,
weight=w1,
weight_scale=w1_scale,
x_scale=pertoken_scale,
group_list=cumsum_group_list(group_list, group_list_type,
0),
bias=bias1,
))
hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
x=hidden_states,
weight=w1,
weight_scale=w1_scale,
x_scale=pertoken_scale,
group_list=cumsum_group_list(group_list, group_list_type, 0),
bias=bias1,
)
elif fusion and not dynamic_eplb:
# gmm1: gate_up_proj & act_fn: swiglu
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
@@ -225,7 +219,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
bias=bias1,
group_list=cumsum_group_list(group_list, group_list_type, 0),
weight_scale=w1_scale[0],
x_scale=pertoken_scale)
x_scale=pertoken_scale,
)
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
else:
@@ -241,21 +236,20 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
output_dtype=_output_dtype,
)[0]
if quantized_hidden_states is not None:
dispose_tensor(quantized_hidden_states)
# act_fn: swiglu
if HAS_TRITON:
from vllm_ascend.ops.triton.activation.swiglu_quant import \
swiglu_quant
from vllm_ascend.ops.triton.activation.swiglu_quant import swiglu_quant
hidden_states, swiglu_out_scale = swiglu_quant(
hidden_states,
group_list=group_list,
group_list_type=group_list_type)
hidden_states, group_list=group_list, group_list_type=group_list_type
)
else:
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
hidden_states)
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
@@ -267,18 +261,20 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
group_list_type=group_list_type,
group_type=0,
group_list=group_list,
output_dtype=_output_dtype)[0]
output_dtype=_output_dtype,
)[0]
return hidden_states
def unquant_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1,
topk_scales: Optional[torch.Tensor] = None,
need_trans: bool = True) -> torch.Tensor:
def unquant_apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
group_list: torch.Tensor,
group_list_type: int = 1,
topk_scales: torch.Tensor | None = None,
need_trans: bool = True,
) -> torch.Tensor:
if need_trans:
w1 = w1.transpose(1, 2)
w2 = w2.transpose(1, 2)
@@ -307,44 +303,50 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,
return hidden_states
def unified_apply_mlp(hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
group_list: torch.Tensor,
w1_scale: Optional[list[torch.Tensor]] = None,
w2_scale: Optional[list[torch.Tensor]] = None,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: Optional[torch.Tensor] = None,
w2_offset: Optional[torch.Tensor] = None,
topk_scales: Optional[torch.Tensor] = None,
with_quant: bool = False,
fusion: bool = False,
need_trans: bool = True,
dynamic_eplb: bool = False) -> torch.Tensor:
def unified_apply_mlp(
hidden_states: torch.Tensor,
w1: torch.Tensor | list[torch.Tensor],
w2: torch.Tensor | list[torch.Tensor],
group_list: torch.Tensor,
w1_scale: list[torch.Tensor] | None = None,
w2_scale: list[torch.Tensor] | None = None,
dynamic_scale: torch.Tensor = None,
group_list_type: int = 1,
w1_scale_bias: torch.Tensor = None,
w2_scale_bias: torch.Tensor = None,
w1_offset: torch.Tensor | None = None,
w2_offset: torch.Tensor | None = None,
topk_scales: torch.Tensor | None = None,
with_quant: bool = False,
fusion: bool = False,
need_trans: bool = True,
dynamic_eplb: bool = False,
) -> torch.Tensor:
if with_quant:
assert w1_scale is not None and w2_scale is not None
return quant_apply_mlp(hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
w1_offset=w1_offset,
w2_offset=w2_offset,
fusion=fusion,
dynamic_eplb=dynamic_eplb)
return quant_apply_mlp(
hidden_states=hidden_states,
w1=w1,
w1_scale=w1_scale,
w2=w2,
w2_scale=w2_scale,
group_list=group_list,
dynamic_scale=dynamic_scale,
group_list_type=group_list_type,
w1_scale_bias=w1_scale_bias,
w2_scale_bias=w2_scale_bias,
w1_offset=w1_offset,
w2_offset=w2_offset,
fusion=fusion,
dynamic_eplb=dynamic_eplb,
)
else:
return unquant_apply_mlp(hidden_states=hidden_states,
w1=w1,
w2=w2,
group_list=group_list,
group_list_type=group_list_type,
topk_scales=topk_scales,
need_trans=need_trans)
return unquant_apply_mlp(
hidden_states=hidden_states,
w1=w1,
w2=w2,
group_list=group_list,
group_list_type=group_list_type,
topk_scales=topk_scales,
need_trans=need_trans,
)

View File

@@ -16,22 +16,23 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Optional
import torch
import torch.distributed as dist
import torch.nn as nn
import torch_npu
from vllm.distributed.parallel_state import (
get_dp_group, get_pcp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size)
get_dp_group,
get_pcp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
from vllm_ascend.utils import (enable_sp, npu_stream_switch,
prefill_context_parallel_enable)
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
class QuantType(Enum):
@@ -51,7 +52,8 @@ class PrepareAndFinalize(ABC):
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
sizes, ranks, and communication settings.
"""
quant_stream: Optional[torch.npu.Stream] = None
quant_stream: torch.npu.Stream | None = None
def __init__(self, moe_config: FusedMoEConfig):
self.moe_config = moe_config
@@ -67,9 +69,8 @@ class PrepareAndFinalize(ABC):
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type: QuantType = QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
quant_type: QuantType = QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""
Prepare tensors before MoE computation. May involve:
- Padding to align communication boundaries
@@ -92,10 +93,9 @@ class PrepareAndFinalize(ABC):
"""
raise NotImplementedError("Prepare not implemented.")
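# Illustrative call flow for a concrete PrepareAndFinalize implementation
# (hypothetical `pf` and `experts` objects, shown as comments only):
#
#   h, logits, scale, metadata = pf.prepare(h, logits, quant_type=QuantType.NONE)
#   routed = experts(h, logits)          # MoE computation on the prepared slice
#   out = pf.finalize(routed, reduce_results=True, context_metadata=metadata)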
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
) -> torch.Tensor:
"""
Finalize MoE output. May involve:
- Gathering sliced tensors across TP ranks
@@ -135,9 +135,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""
Preparation steps:
1. Pad hidden_states and router_logits to next multiple of TP size.
@@ -158,33 +157,24 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
padded_hidden_states_shape = hidden_states.shape
if self.tp_size > 1:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
split_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0)
split_router_logits = torch.tensor_split(router_logits, self.tp_size, dim=0)
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
context_metadata = {
"padded_hidden_states_shape": padded_hidden_states_shape
}
context_metadata = {"padded_hidden_states_shape": padded_hidden_states_shape}
return hidden_states, router_logits, None, context_metadata
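# Illustrative sketch (plain torch, made-up sizes) of the pad-and-slice step above:
import torch
import torch.nn.functional as F

tp_size, tp_rank = 4, 1
h = torch.randn(6, 8)                                    # 6 tokens, hidden size 8
pad = -h.shape[0] % tp_size                              # pad the token dim up to a multiple of tp_size
h = F.pad(h, (0, 0, 0, pad))                             # now 8 x 8
local = torch.tensor_split(h, tp_size, dim=0)[tp_rank]   # this rank keeps rows 2..3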
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
) -> torch.Tensor:
"""
Finalization steps:
1. If TP > 1, all-gather slices to reconstruct full tensor.
@@ -201,20 +191,16 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
# may share memory with original hidden_states. Since shared
# experts may use the original tensor, reusing it would cause
# in-place modification during all_gather, corrupting the data.
padded_hidden_states_shape = context_metadata[
"padded_hidden_states_shape"]
padded_hidden_states_shape = context_metadata["padded_hidden_states_shape"]
gathered_hidden_states = torch.empty(
padded_hidden_states_shape,
device=hidden_states.device,
dtype=hidden_states.dtype)
split_hidden_states = torch.tensor_split(
gathered_hidden_states, self.tp_size, dim=0)
dist.all_gather(list(split_hidden_states), hidden_states,
self.moe_config.tp_group.device_group)
padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype
)
split_hidden_states = torch.tensor_split(gathered_hidden_states, self.tp_size, dim=0)
dist.all_gather(list(split_hidden_states), hidden_states, self.moe_config.tp_group.device_group)
hidden_states = gathered_hidden_states
if self.num_tokens < hidden_states.shape[0]:
hidden_states = hidden_states[:self.num_tokens]
hidden_states = hidden_states[: self.num_tokens]
return hidden_states
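# Illustrative sketch of the reassembly in finalize, with the collective replaced
# by a local copy so it runs anywhere (the real code above uses dist.all_gather
# on the TP group's device_group):
import torch

tp_size, num_tokens = 4, 6
gathered = torch.empty(8, 8)                                 # padded global shape
slots = torch.tensor_split(gathered, tp_size, dim=0)         # views into the output buffer
pieces = [torch.full((2, 8), float(r)) for r in range(tp_size)]
for slot, piece in zip(slots, pieces):                       # stand-in for all_gather
    slot.copy_(piece)
out = gathered[:num_tokens]                                  # drop the padding rows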
@@ -246,9 +232,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""
Preparation steps:
1. Fetch `mc2_mask` and target padding length from forward context.
@@ -278,20 +263,14 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
# Pad if necessary (unless shared expert DP is enabled)
if pad_size > 0 and not self.enable_shared_expert_dp:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
padded_hidden_states_shape = hidden_states.shape
# Slice across TP ranks
if self.tp_size > 1 and not self.enable_shared_expert_dp:
split_hidden_states = torch.tensor_split(hidden_states,
self.tp_size,
dim=0)
split_router_logits = torch.tensor_split(router_logits,
self.tp_size,
dim=0)
split_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0)
split_router_logits = torch.tensor_split(router_logits, self.tp_size, dim=0)
hidden_states = split_hidden_states[self.tp_rank]
router_logits = split_router_logits[self.tp_rank]
@@ -330,9 +309,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""
Preparation steps:
AllGather hidden_states and router_logits to form global tensors.
@@ -341,46 +319,31 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
Tuple of (global_hidden_states, global_router_logits, None)
"""
if enable_sp():
return self._prepare_with_ep_group(hidden_states, router_logits,
quant_type)
return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
return self._prepare_with_dp_group(hidden_states, router_logits,
enable_shared_expert_dp,
replace_allreduce)
return self._prepare_with_dp_group(hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce)
def _prepare_with_ep_group(
self,
hidden_states: torch.Tensor,
router_logits: torch.Tensor,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
pertoken_scale = None
if quant_type == QuantType.W8A8:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
hidden_states)
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
if self.multistream_overlap_gate:
assert PrepareAndFinalize.quant_stream is not None
PrepareAndFinalize.quant_stream.wait_stream(
torch.npu.current_stream())
with npu_stream_switch(PrepareAndFinalize.quant_stream,
enabled=self.multistream_overlap_gate):
hidden_states = fc3_all_gather_and_maybe_unpad_impl(
hidden_states)
PrepareAndFinalize.quant_stream.wait_stream(torch.npu.current_stream())
with npu_stream_switch(PrepareAndFinalize.quant_stream, enabled=self.multistream_overlap_gate):
hidden_states = fc3_all_gather_and_maybe_unpad_impl(hidden_states)
else:
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
hidden_states, True, True)
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
router_logits, True, True)
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states, True, True)
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(router_logits, True, True)
if pertoken_scale is not None:
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
pertoken_scale, True, True)
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(pertoken_scale, True, True)
if self.multistream_overlap_gate:
torch.npu.current_stream().wait_stream(
PrepareAndFinalize.quant_stream)
torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream)
if pertoken_scale is not None:
return (hidden_states, pertoken_scale), router_logits, None, None
@@ -393,9 +356,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
router_logits: torch.Tensor,
enable_shared_expert_dp: bool = False,
replace_allreduce: bool = False,
quant_type=QuantType.NONE
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
Optional[torch.Tensor]]:
quant_type=QuantType.NONE,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
"""
Preparation steps:
1. Fetch max token count across DP group from forward context.
@@ -413,16 +375,12 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
self.num_tokens = hidden_states.shape[0]
pad_size = max_tokens_across_dp - self.num_tokens
if pad_size > 0:
hidden_states = nn.functional.pad(hidden_states,
(0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits,
(0, 0, 0, pad_size))
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
# All-gather across DP group
hidden_states = self.moe_config.dp_group.all_gather(
hidden_states, 0)
router_logits = self.moe_config.dp_group.all_gather(
router_logits, 0)
hidden_states = self.moe_config.dp_group.all_gather(hidden_states, 0)
router_logits = self.moe_config.dp_group.all_gather(router_logits, 0)
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().all_gather(
@@ -436,10 +394,9 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
return hidden_states, router_logits, None, None
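# Illustrative numbers (made up) for the DP padding rule above: every rank pads
# to the DP-wide max before the all-gather, and finalize later slices back to
# its own num_tokens.
local_tokens = [5, 7, 3]                                       # tokens held by DP ranks 0..2
max_tokens_across_dp = max(local_tokens)                       # 7
pad_sizes = [max_tokens_across_dp - n for n in local_tokens]   # [2, 0, 4]
gathered_rows = len(local_tokens) * max_tokens_across_dp       # 21 rows after all_gather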
def finalize(self,
hidden_states: torch.Tensor,
reduce_results: bool,
context_metadata: Optional[dict] = None) -> torch.Tensor:
def finalize(
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
) -> torch.Tensor:
"""
Finalization steps:
Reduce Scatter hidden states.
@@ -452,8 +409,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
return self._finalize_with_dp_group(hidden_states, reduce_results)
def _finalize_with_ep_group(self,
hidden_states: torch.Tensor) -> torch.Tensor:
def _finalize_with_ep_group(self, hidden_states: torch.Tensor) -> torch.Tensor:
"""
Argument `reduce_results` is not needed in this func. Given sequence parallelism is enabled:
1. Reduce_results is False usually happens when models have shared experts and need to
@@ -463,13 +419,11 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
2. Reduce_results is True usually happens when the model has no shared experts. We still do reduce scatter
here, then skip allreduce in FusedMoe.
"""
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
hidden_states, True)
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states, True)
return hidden_states
def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
reduce_results: bool) -> torch.Tensor:
def _finalize_with_dp_group(self, hidden_states: torch.Tensor, reduce_results: bool) -> torch.Tensor:
"""
Finalization steps:
1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
@@ -481,9 +435,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
"""
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
hidden_states = hidden_states[:self.num_tokens]
hidden_states = hidden_states[: self.num_tokens]
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
hidden_states = get_pcp_group().reduce_scatter(hidden_states,
dim=0)
hidden_states = get_pcp_group().reduce_scatter(hidden_states, dim=0)
return hidden_states
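# Tiny worked example (made-up values) of the reduce-scatter above: each DP rank
# receives the element-wise sum of its own chunk from every rank, then trims the
# padding rows it added in prepare.
rank0 = [1, 2, 3, 4]                                           # per-rank outputs, chunk size 2
rank1 = [10, 20, 30, 40]
rank0_keeps = [a + b for a, b in zip(rank0[:2], rank1[:2])]    # [11, 22]
rank1_keeps = [a + b for a, b in zip(rank0[2:], rank1[2:])]    # [33, 44]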

View File

@@ -22,7 +22,6 @@
# limitations under the License.
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Optional
import torch
import torch_npu
@@ -30,10 +29,8 @@ from vllm.config import get_current_vllm_config
from vllm.distributed.parallel_state import get_ep_group
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.fused_moe.comm_utils import (
async_all_to_all, gather_from_sequence_parallel_region)
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
is_hierarchical_communication_enabled)
from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
@dataclass
@@ -52,7 +49,6 @@ class TokenCombineResult:
class MoETokenDispatcher(ABC):
def __init__(self, **kwargs) -> None:
"""
Initialize the MoE Token Dispatcher.
@@ -79,26 +75,24 @@ class MoETokenDispatcher(ABC):
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: Optional[torch.Tensor] = None,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: Optional[torch.Tensor] = None,
pertoken_scale: torch.Tensor | None = None,
) -> TokenDispatchResult:
raise NotImplementedError("Dispatch function not implemented.")
@abstractmethod
def token_combine(self,
hidden_states: torch.Tensor,
context_metadata: dict,
bias: torch.Tensor | None = None) -> TokenCombineResult:
def token_combine(
self, hidden_states: torch.Tensor, context_metadata: dict, bias: torch.Tensor | None = None
) -> TokenCombineResult:
raise NotImplementedError("Combine function not implemented.")
class TokenDispatcherWithMC2(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
device_group = get_mc2_group().device_group
@@ -108,10 +102,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
self.ep_rank_id = get_mc2_group().rank_in_group
self.ep_world_size = get_mc2_group().world_size
self.enable_dispatch_v2 = hasattr(torch_npu,
"npu_moe_distribute_dispatch_v2")
self.need_extra_args = (
get_ascend_device_type() == AscendDeviceType.A3)
self.enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
self.need_extra_args = get_ascend_device_type() == AscendDeviceType.A3
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
@@ -126,10 +118,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
compilation_config = vllm_config.compilation_config
speculative_config = vllm_config.speculative_config
tp_size = vllm_config.parallel_config.tensor_parallel_size
uniform_decode_query_len = 1 if not speculative_config else \
1 + speculative_config.num_speculative_tokens
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
0)
uniform_decode_query_len = 1 if not speculative_config else 1 + speculative_config.num_speculative_tokens
decode_max_num_seqs = getattr(scheduler_config, "decode_max_num_seqs", 0)
max_num_reqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
if compilation_config.cudagraph_capture_sizes:
max_num_tokens = compilation_config.max_cudagraph_capture_size
@@ -167,44 +157,56 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"ep_rank_id": self.ep_rank_id,
}
if self.need_extra_args:
stage1_kwargs.update({
"group_tp": self.moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
stage1_kwargs.update(
{
"group_tp": self.moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
}
)
if self.need_expert_scale:
stage1_kwargs.update({
"expert_scales":
topk_weights.to(torch.float32),
})
stage1_kwargs.update(
{
"expert_scales": topk_weights.to(torch.float32),
}
)
kwargs_mc2.update(stage1_kwargs)
return kwargs_mc2
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: Optional[torch.Tensor] = None):
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
):
self.with_quant = with_quant
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
topk_ids, expert_map,
mc2_mask,
global_redundant_expert_num)
output = torch_npu.npu_moe_distribute_dispatch_v2(
**kwargs_mc2
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
**kwargs_mc2)
kwargs_mc2 = self.get_dispatch_mc2_kwargs(
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num
)
output = (
torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2)
if self.enable_dispatch_v2
else torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
)
# comm_stream.wait_stream(torch.npu.current_stream())
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \
ep_recv_counts, tp_recv_counts, expand_scales = output[0:7]
(
expand_x,
dynamic_scale,
assist_info_for_combine,
expert_token_nums,
ep_recv_counts,
tp_recv_counts,
expand_scales,
) = output[0:7]
context_metadata = {
"topk_ids": topk_ids,
@@ -213,18 +215,19 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
"ep_recv_counts": ep_recv_counts,
"tp_recv_counts": tp_recv_counts,
"assist_info_for_combine": assist_info_for_combine,
"expand_scales": expand_scales
"expand_scales": expand_scales,
}
group_list_type = 0
return TokenDispatchResult(hidden_states=expand_x,
dynamic_scale=dynamic_scale,
group_list=expert_token_nums,
group_list_type=group_list_type,
context_metadata=context_metadata)
return TokenDispatchResult(
hidden_states=expand_x,
dynamic_scale=dynamic_scale,
group_list=expert_token_nums,
group_list_type=group_list_type,
context_metadata=context_metadata,
)
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
context_metadata: dict):
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, context_metadata: dict):
expert_map = context_metadata["expert_map"]
topk_ids = context_metadata["topk_ids"]
topk_weights = context_metadata["topk_weights"]
@@ -246,9 +249,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
}
if self.with_quant:
tp_recv_counts = torch.empty(1,
dtype=torch.int32,
device=hidden_states.device)
tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device)
stage3_kwargs = {
"ep_send_counts": ep_recv_counts,
@@ -264,12 +265,14 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
stage3_kwargs["expand_idx"] = assist_info_for_combine
if self.need_extra_args:
stage3_kwargs.update({
"tp_send_counts": tp_recv_counts,
"group_tp": self.moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
})
stage3_kwargs.update(
{
"tp_send_counts": tp_recv_counts,
"group_tp": self.moe_all_to_all_group_name,
"tp_world_size": 1,
"tp_rank_id": 0,
}
)
kwargs_mc2.update(stage3_kwargs)
return kwargs_mc2
@@ -277,57 +280,58 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
def token_combine(self, hidden_states, context_metadata, bias=None):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states,
context_metadata)
combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \
if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, context_metadata)
combined_output = (
torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2)
if self.enable_dispatch_v2
else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
)
return TokenCombineResult(routed_out=combined_output, )
return TokenCombineResult(
routed_out=combined_output,
)
class TokenDispatcherWithAllGather(MoETokenDispatcher):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.apply_router_weight_on_input = False
self.max_num_tokens = kwargs.get("max_num_tokens")
num_experts_local = kwargs.get("num_local_experts", 0)
self.num_experts_local = num_experts_local.item() if torch.is_tensor(
num_experts_local) else int(num_experts_local)
self.num_experts_local = (
num_experts_local.item() if torch.is_tensor(num_experts_local) else int(num_experts_local)
)
self.original_shape = None
self.with_quant = False
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: Optional[torch.Tensor] = None):
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
):
self.with_quant = with_quant
self.original_shape = hidden_states.shape
num_tokens = hidden_states.shape[:-1].numel()
self.apply_router_weight_on_input = apply_router_weight_on_input
if self.apply_router_weight_on_input:
assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)"
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
_, topk = topk_weights.shape
assert (
topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * \
topk_weights.to(hidden_states.dtype)
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
if expert_map is not None:
global_num_experts = len(expert_map) + global_redundant_expert_num
mask = (expert_map[topk_ids] != -1)
mask = expert_map[topk_ids] != -1
topk_weights = topk_weights * mask
first_expert_idx = get_ep_group(
).rank_in_group * self.num_experts_local
first_expert_idx = get_ep_group().rank_in_group * self.num_experts_local
last_expert_idx = first_expert_idx + self.num_experts_local
else:
first_expert_idx = 0
@@ -344,15 +348,12 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
expert_tokens_num_type=1,
expert_tokens_num_flag=True,
active_expert_range=[first_expert_idx, last_expert_idx],
quant_mode=1
if self.with_quant and pertoken_scale is None else -1,
))
quant_mode=1 if self.with_quant and pertoken_scale is None else -1,
)
)
expert_tokens = expert_tokens.to(torch.int64)
group_list_type = 1 # `count` mode
context_metadata = {
"topk_weights": topk_weights,
"expanded_row_idx": expanded_row_idx
}
context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}
return TokenDispatchResult(
hidden_states=sorted_hidden_states,
@@ -367,7 +368,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
final_hidden_states = torch_npu.npu_moe_token_unpermute(
permuted_tokens=hidden_states,
sorted_indices=torch.abs(context_metadata["expanded_row_idx"]),
probs=context_metadata["topk_weights"])
probs=context_metadata["topk_weights"],
)
if len(self.original_shape) == 3:
final_hidden_states = final_hidden_states.view(self.original_shape)
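# Illustrative sketch (plain torch, made-up shapes) of apply_router_weight_on_input
# with topk=1: the single routing weight is folded into each token before dispatch,
# which is why the dispatch code above requires topk == 1.
import torch

hidden = torch.ones(3, 4)                                # 3 tokens, hidden size 4
topk_weights = torch.tensor([[0.5], [2.0], [1.0]])       # shape (num_tokens, topk=1)
hidden = hidden * topk_weights.to(hidden.dtype)          # row-wise scaling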
@@ -398,35 +400,33 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
device=torch.npu.current_device(),
)
local_expert_indices_offset = (self.ep_rank * self.num_local_experts)
local_expert_indices_offset = self.ep_rank * self.num_local_experts
self.local_expert_indices = [
local_expert_indices_offset + i
for i in range(self.num_local_experts)
]
assert (len(self.local_expert_indices) == self.num_local_experts
), "Invalid local expert indices"
self.local_expert_indices = [local_expert_indices_offset + i for i in range(self.num_local_experts)]
assert len(self.local_expert_indices) == self.num_local_experts, "Invalid local expert indices"
for i in range(len(self.local_expert_indices) - 1):
assert (self.local_expert_indices[i] ==
self.local_expert_indices[i + 1] -
1), "local_expert_indices must be continuous"
assert self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1, (
"local_expert_indices must be continuous"
)
# TODO: Try local_rank = ep_group.rank_in_group
local_rank = torch.distributed.get_rank(group=self.ep_group)
backend = self.ep_group._get_backend(torch.device("npu"))
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
def token_dispatch(self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: Optional[torch.Tensor] = None,
global_redundant_expert_num: int = 0,
mc2_mask: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: Optional[torch.Tensor] = None):
def token_dispatch(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
expert_map: torch.Tensor | None = None,
global_redundant_expert_num: int = 0,
mc2_mask: torch.Tensor | None = None,
apply_router_weight_on_input: bool = False,
with_quant: bool = False,
dynamic_eplb: bool = False,
pertoken_scale: torch.Tensor | None = None,
):
self.with_quant = with_quant
self.hidden_shape = hidden_states.shape
@@ -442,35 +442,32 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
dynamic_scale_after_all2all = None
if self.with_quant:
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(
permutated_local_input_tokens)
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(permutated_local_input_tokens)
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
dynamic_scale, output_splits, input_splits, self.ep_group)
dynamic_scale, output_splits, input_splits, self.ep_group
)
permute2_ep_all_to_all_handle.wait()
dynamic_scale.untyped_storage().resize_(0)
_, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
permutated_local_input_tokens, output_splits, input_splits,
self.ep_group)
permutated_local_input_tokens, output_splits, input_splits, self.ep_group
)
permute1_ep_all_to_all_handle.wait()
permutated_local_input_tokens.untyped_storage().resize_(0)
# Postprocess
global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = self._dispatch_postprocess(
global_input_tokens, dynamic_scale_after_all2all,
global_input_tokens_local_experts_indices)
global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = (
self._dispatch_postprocess(
global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices
)
)
context_metadata = {
"input_splits":
input_splits,
"output_splits":
output_splits,
"topk_weights":
topk_weights,
"reversed_local_input_permutation_mapping":
reversed_local_input_permutation_mapping,
"reversed_global_input_permutation_mapping":
reversed_global_input_permutation_mapping
"input_splits": input_splits,
"output_splits": output_splits,
"topk_weights": topk_weights,
"reversed_local_input_permutation_mapping": reversed_local_input_permutation_mapping,
"reversed_global_input_permutation_mapping": reversed_global_input_permutation_mapping,
}
return TokenDispatchResult(
@@ -485,8 +482,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
# 1. Preprocess using metadata
hidden_states = self._combine_preprocess(hidden_states,
context_metadata)
hidden_states = self._combine_preprocess(hidden_states, context_metadata)
# 2. AllToAll
_, permutated_local_input_tokens, handle = async_all_to_all(
@@ -499,8 +495,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
hidden_states.untyped_storage().resize_(0)
# 3. Postprocess using metadata
output = self._combine_postprocess(permutated_local_input_tokens,
context_metadata)
output = self._combine_postprocess(permutated_local_input_tokens, context_metadata)
return TokenCombineResult(routed_out=output)
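# The dispatch/combine above relies on the async collective idiom from
# comm_utils; a sketch of that idiom, mirroring the calls used in this class
# (shown as comments only):
#
#   _, out, handle = async_all_to_all(tokens, output_splits, input_splits, ep_group)
#   ...                                    # independent work can overlap here
#   handle.wait()                          # `out` is valid only after the wait
#   tokens.untyped_storage().resize_(0)    # release the send buffer once it is no longer needed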
@@ -534,42 +529,39 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
)
def _preprocess(self, topk_ids: torch.Tensor):
num_local_tokens_per_expert = torch.histc(topk_ids,
bins=self.num_experts,
min=0,
max=self.num_experts)
num_local_tokens_per_expert = torch.histc(topk_ids, bins=self.num_experts, min=0, max=self.num_experts)
ep_size = self.ep_size
self.num_out_tokens = topk_ids.numel()
input_splits = (num_local_tokens_per_expert.reshape(
ep_size,
self.num_local_experts).sum(axis=1).to(torch.device("cpu"),
non_blocking=True).numpy())
input_splits = (
num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
.sum(axis=1)
.to(torch.device("cpu"), non_blocking=True)
.numpy()
)
num_global_tokens_per_expert = gather_from_sequence_parallel_region(
num_local_tokens_per_expert,
group=self.ep_group).reshape(ep_size, self.num_experts)
num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
0]:self.local_expert_indices[-1] + 1]
num_local_tokens_per_expert, group=self.ep_group
).reshape(ep_size, self.num_experts)
num_global_tokens_per_local_expert = num_global_tokens_per_expert[
:, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
]
if num_global_tokens_per_local_expert is None:
raise ValueError(
"num_global_tokens_per_local_expert must be set before sum.")
raise ValueError("num_global_tokens_per_local_expert must be set before sum.")
output_splits = (num_global_tokens_per_local_expert.sum(axis=-1).to(
torch.device("cpu"), non_blocking=True).numpy())
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(
axis=0)
output_splits = (
num_global_tokens_per_local_expert.sum(axis=-1).to(torch.device("cpu"), non_blocking=True).numpy()
)
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(axis=0)
global_input_tokens_local_experts_indices = None
if self.num_local_experts > 1:
if num_global_tokens_per_local_expert is None:
raise ValueError(
"num_global_tokens_per_local_expert must be set before operations."
)
raise ValueError("num_global_tokens_per_local_expert must be set before operations.")
global_input_tokens_local_experts_indices = torch.repeat_interleave(
self.expert_ids_per_ep_rank,
num_global_tokens_per_local_expert.ravel())
self.expert_ids_per_ep_rank, num_global_tokens_per_local_expert.ravel()
)
else:
torch.npu.synchronize()
@@ -581,45 +573,41 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
global_input_tokens_local_experts_indices,
)
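# Worked sketch (made-up sizes) of the split computation above: ep_size = 2,
# num_local_experts = 2 (4 experts total), 6 routed slots on this rank.
import torch

topk_ids = torch.tensor([0, 1, 1, 2, 3, 3])
per_expert = torch.histc(topk_ids.float(), bins=4, min=0, max=4)   # tensor([1., 2., 1., 2.])
input_splits = per_expert.reshape(2, 2).sum(dim=1)                 # tensor([3., 3.]): 3 tokens to each EP rank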
def _dispatch_postprocess(self, global_input_tokens,
dynamic_scale_after_all2all,
global_input_tokens_local_experts_indices):
def _dispatch_postprocess(
self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices
):
# Early return if no local experts or no tokens
if self.num_local_experts <= 1:
return global_input_tokens, dynamic_scale_after_all2all, None
# Handle quantized case
if self.with_quant:
assert global_input_tokens_local_experts_indices is not None, \
assert global_input_tokens_local_experts_indices is not None, (
"global_input_tokens_local_experts_indices must be provided"
)
dynamic_scale_after_all2all, _ = torch_npu.npu_moe_token_permute(
dynamic_scale_after_all2all.unsqueeze(-1),
global_input_tokens_local_experts_indices)
dynamic_scale_after_all2all = dynamic_scale_after_all2all.squeeze(
-1)
dynamic_scale_after_all2all.unsqueeze(-1), global_input_tokens_local_experts_indices
)
dynamic_scale_after_all2all = dynamic_scale_after_all2all.squeeze(-1)
# Non-quantized case
global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
global_input_tokens, global_input_tokens_local_experts_indices)
global_input_tokens, global_input_tokens_local_experts_indices
)
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
def _combine_preprocess(self, hidden_states: torch.Tensor,
context_metadata: dict) -> torch.Tensor:
def _combine_preprocess(self, hidden_states: torch.Tensor, context_metadata: dict) -> torch.Tensor:
# Unpermutation 2: expert output to AlltoAll input
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
rev_global = context_metadata[
"reversed_global_input_permutation_mapping"]
hidden_states = torch_npu.npu_moe_token_unpermute(
hidden_states, rev_global)
rev_global = context_metadata["reversed_global_input_permutation_mapping"]
hidden_states = torch_npu.npu_moe_token_unpermute(hidden_states, rev_global)
return hidden_states
def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor,
context_metadata: dict) -> torch.Tensor:
def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor, context_metadata: dict) -> torch.Tensor:
# Unpermutation 1: AlltoAll output to output
output = torch_npu.npu_moe_token_unpermute(
permuted_tokens=permutated_local_input_tokens,
sorted_indices=context_metadata[
"reversed_local_input_permutation_mapping"].to(torch.int32),
sorted_indices=context_metadata["reversed_local_input_permutation_mapping"].to(torch.int32),
probs=context_metadata["topk_weights"],
restore_shape=self.hidden_shape_before_permute,
)
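# Illustrative sketch of the general unpermute-and-weight idea behind the combine
# path (semantics deliberately simplified; not a drop-in replacement for
# torch_npu.npu_moe_token_unpermute):
import torch

num_tokens, topk, hidden = 2, 2, 3
permuted = torch.arange(12, dtype=torch.float32).reshape(4, 3)   # expert-major rows
sorted_indices = torch.tensor([2, 0, 3, 1])   # permuted row i belongs to flat slot sorted_indices[i]
probs = torch.tensor([[0.7, 0.3], [0.5, 0.5]])

restored = torch.empty_like(permuted)
restored[sorted_indices] = permuted                              # undo the permutation
out = (restored.view(num_tokens, topk, hidden) * probs.unsqueeze(-1)).sum(dim=1)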