### What this PR does / why we need it?
**Scope of Changes**:
| File Path |
| :--- |
| `vllm_ascend/ops/fused_moe/comm_utils.py` |
| `vllm_ascend/ops/fused_moe/experts_selector.py` |
| `vllm_ascend/ops/fused_moe/fused_moe.py` |
| `vllm_ascend/ops/fused_moe/moe_comm_method.py` |
| `vllm_ascend/ops/fused_moe/moe_mlp.py` |
| `vllm_ascend/ops/fused_moe/prepare_finalize.py` |
| `vllm_ascend/ops/fused_moe/token_dispatcher.py` |
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.14.0
- vLLM main:
d68209402d
Signed-off-by: MrZ20 <2609716663@qq.com>
Signed-off-by: SILONG ZENG <2609716663@qq.com>
This commit is contained in:
@@ -62,9 +62,6 @@ exclude = [
|
||||
"vllm_ascend/worker/v2/**",
|
||||
"vllm_ascend/worker/npu_input_batch.py",
|
||||
"vllm_ascend/ops/rotary_embedding.py",
|
||||
|
||||
# (11)
|
||||
"vllm_ascend/ops/fused_moe/**",
|
||||
]
|
||||
|
||||
[tool.ruff.lint]
|
||||
|
||||
@@ -23,11 +23,7 @@ import torch_npu
|
||||
COMM_STREAM = None
|
||||
|
||||
|
||||
def async_all_to_all(input_,
|
||||
output_split_sizes,
|
||||
input_split_sizes,
|
||||
group,
|
||||
event=None):
|
||||
def async_all_to_all(input_, output_split_sizes, input_split_sizes, group, event=None):
|
||||
if output_split_sizes is None:
|
||||
# Equal split (all2all)
|
||||
a2a_out = torch.empty_like(input_)
|
||||
@@ -43,8 +39,7 @@ def async_all_to_all(input_,
|
||||
# multi stream wait event
|
||||
global COMM_STREAM
|
||||
if COMM_STREAM is None:
|
||||
COMM_STREAM = torch_npu.npu.Stream(
|
||||
device=torch.npu.current_device())
|
||||
COMM_STREAM = torch_npu.npu.Stream(device=torch.npu.current_device())
|
||||
with torch_npu.npu.stream(COMM_STREAM):
|
||||
event.wait()
|
||||
handle = dist.all_to_all_single(
|
||||
@@ -53,14 +48,17 @@ def async_all_to_all(input_,
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True)
|
||||
async_op=True,
|
||||
)
|
||||
else:
|
||||
handle = dist.all_to_all_single(a2a_out,
|
||||
input_.contiguous(),
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True)
|
||||
handle = dist.all_to_all_single(
|
||||
a2a_out,
|
||||
input_.contiguous(),
|
||||
output_split_sizes=output_split_sizes,
|
||||
input_split_sizes=input_split_sizes,
|
||||
group=group,
|
||||
async_op=True,
|
||||
)
|
||||
return input_, a2a_out, handle
|
||||
|
||||
|
||||
@@ -86,19 +84,12 @@ def _gather_along_first_dim(input_, group, output_split_sizes=None):
|
||||
if output_split_sizes is None:
|
||||
dim_size[0] = dim_size[0] * world_size
|
||||
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
torch.distributed.all_gather_into_tensor(output,
|
||||
input_.contiguous(),
|
||||
group=group)
|
||||
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.npu.current_device())
|
||||
torch.distributed.all_gather_into_tensor(output, input_.contiguous(), group=group)
|
||||
else:
|
||||
dim_size[0] = sum(output_split_sizes)
|
||||
output = torch.empty(dim_size,
|
||||
dtype=input_.dtype,
|
||||
device=torch.npu.current_device())
|
||||
output_tensor_list = list(
|
||||
torch.split(output, output_split_sizes, dim=0))
|
||||
output = torch.empty(dim_size, dtype=input_.dtype, device=torch.npu.current_device())
|
||||
output_tensor_list = list(torch.split(output, output_split_sizes, dim=0))
|
||||
torch.distributed.all_gather(output_tensor_list, input_, group=group)
|
||||
|
||||
return output
|
||||
@@ -110,4 +101,4 @@ def gather_from_sequence_parallel_region(
|
||||
output_split_sizes=None,
|
||||
):
|
||||
"""Wrapper for autograd function: forward: AG, backward: RS <first dim>"""
|
||||
return _gather_along_first_dim(input_, group, output_split_sizes)
|
||||
return _gather_along_first_dim(input_, group, output_split_sizes)
|
||||
|
||||
@@ -14,26 +14,28 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from typing import Callable, Optional
|
||||
from collections.abc import Callable
|
||||
|
||||
import torch
|
||||
|
||||
from vllm_ascend.utils import get_weight_prefetch_method
|
||||
|
||||
|
||||
def select_experts(hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
indices_type: Optional[torch.dtype] = None,
|
||||
global_num_experts: int = -1):
|
||||
def select_experts(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
indices_type: torch.dtype | None = None,
|
||||
global_num_experts: int = -1,
|
||||
):
|
||||
"""
|
||||
Fused experts with select experts.
|
||||
|
||||
@@ -58,8 +60,7 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
# prefetch w1_w3_proj.weight preprocess
|
||||
weight_prefetch_method = get_weight_prefetch_method()
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(
|
||||
hidden_states, "gate_up")
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_preprocess(hidden_states, "gate_up")
|
||||
is_support_npu_moe_gating_top_k = check_npu_moe_gating_top_k(
|
||||
hidden_states=hidden_states,
|
||||
top_k=top_k,
|
||||
@@ -67,7 +68,8 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
scoring_func=scoring_func,
|
||||
custom_routing_function=custom_routing_function)
|
||||
custom_routing_function=custom_routing_function,
|
||||
)
|
||||
|
||||
if is_support_npu_moe_gating_top_k:
|
||||
topk_weights, topk_ids = _select_experts_with_fusion_ops(
|
||||
@@ -81,7 +83,8 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
num_expert_group=num_expert_group,
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
global_num_experts=global_num_experts)
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
else:
|
||||
topk_weights, topk_ids = _native_select_experts(
|
||||
hidden_states=hidden_states,
|
||||
@@ -100,14 +103,15 @@ def select_experts(hidden_states: torch.Tensor,
|
||||
|
||||
|
||||
def check_npu_moe_gating_top_k(
|
||||
hidden_states: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
scoring_func: str = "softmax",
|
||||
custom_routing_function: Optional[Callable] = None):
|
||||
if scoring_func == "sigmoid" and not renormalize: #sigmoid + renorm=0 is not supported in current branch
|
||||
hidden_states: torch.Tensor,
|
||||
top_k: int,
|
||||
renormalize: bool,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
custom_routing_function: Callable | None = None,
|
||||
):
|
||||
if scoring_func == "sigmoid" and not renormalize: # sigmoid + renorm=0 is not supported in current branch
|
||||
return False
|
||||
if custom_routing_function is not None:
|
||||
return False
|
||||
@@ -115,39 +119,39 @@ def check_npu_moe_gating_top_k(
|
||||
return False
|
||||
topk_group = topk_group if topk_group is not None else 1
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 1
|
||||
if not (num_expert_group > 0 and hidden_states.shape[-1] % num_expert_group
|
||||
== 0 and hidden_states.shape[-1] // num_expert_group > 2):
|
||||
if not (
|
||||
num_expert_group > 0
|
||||
and hidden_states.shape[-1] % num_expert_group == 0
|
||||
and hidden_states.shape[-1] // num_expert_group > 2
|
||||
):
|
||||
return False
|
||||
if topk_group < 1 or topk_group > num_expert_group:
|
||||
return False
|
||||
if top_k < 1 or \
|
||||
top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
|
||||
if top_k < 1 or top_k > (hidden_states.shape[-1] / (num_expert_group * topk_group)):
|
||||
return False
|
||||
if topk_group * hidden_states.shape[-1] / num_expert_group < top_k:
|
||||
if topk_group * hidden_states.shape[-1] / num_expert_group < top_k: # noqa: SIM103
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _native_grouped_topk(
|
||||
topk_weights: torch.Tensor,
|
||||
num_expert_group: Optional[int],
|
||||
topk_group: Optional[int],
|
||||
num_expert_group: int | None,
|
||||
topk_group: int | None,
|
||||
):
|
||||
topk_group = 0 if topk_group is None else topk_group
|
||||
num_expert_group = 0 if num_expert_group is None else num_expert_group
|
||||
|
||||
num_token = topk_weights.shape[0]
|
||||
grouped_weights = topk_weights.view(num_token, num_expert_group,
|
||||
-1).max(dim=-1).values
|
||||
topk_group_indices = torch.topk(grouped_weights.to(torch.float32),
|
||||
k=topk_group,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
grouped_weights = topk_weights.view(num_token, num_expert_group, -1).max(dim=-1).values
|
||||
topk_group_indices = torch.topk(grouped_weights.to(torch.float32), k=topk_group, dim=-1, sorted=False)[1]
|
||||
topk_group_mask = torch.zeros_like(grouped_weights)
|
||||
topk_group_mask.scatter_(1, topk_group_indices, 1)
|
||||
topk_weight_mask = (topk_group_mask.unsqueeze(-1).expand(
|
||||
num_token, num_expert_group,
|
||||
topk_weights.shape[-1] // num_expert_group).reshape(num_token, -1))
|
||||
topk_weight_mask = (
|
||||
topk_group_mask.unsqueeze(-1)
|
||||
.expand(num_token, num_expert_group, topk_weights.shape[-1] // num_expert_group)
|
||||
.reshape(num_token, -1)
|
||||
)
|
||||
topk_weights = topk_weights.masked_fill(~topk_weight_mask.bool(), 0.0)
|
||||
|
||||
return topk_weights
|
||||
@@ -163,9 +167,13 @@ def _renormalize_topk_weights(
|
||||
|
||||
|
||||
def _select_expert_use_group_topk(
|
||||
topk_weights: torch.Tensor, topk_group: Optional[int],
|
||||
renormalize: bool, top_k: int, num_expert_group: Optional[int],
|
||||
e_score_correction_bias: Optional[torch.Tensor]):
|
||||
topk_weights: torch.Tensor,
|
||||
topk_group: int | None,
|
||||
renormalize: bool,
|
||||
top_k: int,
|
||||
num_expert_group: int | None,
|
||||
e_score_correction_bias: torch.Tensor | None,
|
||||
):
|
||||
assert topk_group is not None
|
||||
assert num_expert_group is not None
|
||||
|
||||
@@ -177,47 +185,38 @@ def _select_expert_use_group_topk(
|
||||
|
||||
# TODO: Change to npu_group_topk when the latest CANN and NNAL is available
|
||||
# >>> torch_npu._npu_group_topk(topk_weights, group_num=num_expert_group, k=topk_group)
|
||||
topk_weights = _native_grouped_topk(topk_weights, num_expert_group,
|
||||
topk_group)
|
||||
topk_weights = _native_grouped_topk(topk_weights, num_expert_group, topk_group)
|
||||
# TODO bfloat16 is not supported in torch.topk with ge graph.
|
||||
if e_score_correction_bias is not None:
|
||||
topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)[1]
|
||||
topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)[1]
|
||||
# Use original unbiased scores for the routing weights
|
||||
topk_weights = original_weights.gather(1, topk_ids)
|
||||
else:
|
||||
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32),
|
||||
k=top_k,
|
||||
dim=-1,
|
||||
sorted=False)
|
||||
topk_weights, topk_ids = torch.topk(topk_weights.to(torch.float32), k=top_k, dim=-1, sorted=False)
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
topk_weights = _renormalize_topk_weights(topk_weights, renormalize)
|
||||
return topk_weights, topk_ids
|
||||
|
||||
|
||||
def _select_experts_with_fusion_ops(
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
e_score_correction_bias: Optional[torch.Tensor],
|
||||
topk_group: Optional[int],
|
||||
num_expert_group: Optional[int],
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
global_num_experts: int = -1):
|
||||
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
e_score_correction_bias: torch.Tensor | None,
|
||||
topk_group: int | None,
|
||||
num_expert_group: int | None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor=1.0,
|
||||
global_num_experts: int = -1,
|
||||
):
|
||||
topk_group = topk_group if topk_group is not None else 1
|
||||
num_expert_group = num_expert_group if num_expert_group is not None else 1
|
||||
renorm = int(renormalize)
|
||||
norm_type = 0 if scoring_func == "softmax" else 1
|
||||
if e_score_correction_bias is not None and \
|
||||
e_score_correction_bias.dtype != router_logits.dtype:
|
||||
e_score_correction_bias = e_score_correction_bias.to(
|
||||
router_logits.dtype)
|
||||
if e_score_correction_bias is not None and e_score_correction_bias.dtype != router_logits.dtype:
|
||||
e_score_correction_bias = e_score_correction_bias.to(router_logits.dtype)
|
||||
topk_weights, topk_ids, _ = torch.ops._C_ascend.moe_gating_top_k(
|
||||
router_logits,
|
||||
k=top_k,
|
||||
@@ -228,7 +227,7 @@ def _select_experts_with_fusion_ops(
|
||||
norm_type=norm_type, # 0: softmax; 1: sigmoid
|
||||
out_flag=False,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
eps=float(1e-20),
|
||||
eps=1e-20,
|
||||
bias_opt=e_score_correction_bias,
|
||||
)
|
||||
|
||||
@@ -241,12 +240,12 @@ def _native_select_experts(
|
||||
top_k: int,
|
||||
use_grouped_topk: bool,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
global_num_experts: Optional[torch.Tensor] = None
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
global_num_experts: torch.Tensor | None = None,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Select top-k experts based on router logits.
|
||||
@@ -285,7 +284,8 @@ def _native_select_experts(
|
||||
renormalize=renormalize,
|
||||
topk_group=topk_group,
|
||||
num_expert_group=num_expert_group,
|
||||
e_score_correction_bias=e_score_correction_bias)
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
)
|
||||
|
||||
if custom_routing_function is not None:
|
||||
topk_weights, topk_ids = custom_routing_function(
|
||||
@@ -293,7 +293,8 @@ def _native_select_experts(
|
||||
gating_output=router_logits,
|
||||
topk=top_k,
|
||||
renormalize=renormalize,
|
||||
global_num_experts=global_num_experts)
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
# Required by npu_moe_init_routing
|
||||
topk_ids = topk_ids.to(torch.int32)
|
||||
return topk_weights, topk_ids
|
||||
@@ -318,8 +319,7 @@ def zero_experts_compute(
|
||||
if zero_expert_type == "identity":
|
||||
zero_expert_mask = expert_indices < num_experts
|
||||
zero_expert_scales = expert_scales.clone()
|
||||
zero_expert_scales = torch.where(zero_expert_mask, 0.0,
|
||||
zero_expert_scales)
|
||||
zero_expert_scales = torch.where(zero_expert_mask, 0.0, zero_expert_scales)
|
||||
|
||||
hidden_states = hidden_states.unsqueeze(1)
|
||||
zero_expert_scales = zero_expert_scales.unsqueeze(2)
|
||||
|
||||
@@ -14,40 +14,37 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass, field
|
||||
from functools import wraps
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed import (get_dp_group, get_ep_group, get_tp_group,
|
||||
tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group, tensor_model_parallel_all_reduce
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.logger import logger
|
||||
from vllm.model_executor.layers.fused_moe.config import FusedMoEConfig
|
||||
from vllm.model_executor.layers.fused_moe.layer import (
|
||||
FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map)
|
||||
from vllm.model_executor.layers.fused_moe.shared_fused_moe import \
|
||||
SharedFusedMoE
|
||||
from vllm.model_executor.layers.fused_moe.layer import FusedMoE, UnquantizedFusedMoEMethod, get_compressed_expert_map
|
||||
from vllm.model_executor.layers.fused_moe.shared_fused_moe import SharedFusedMoE
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.eplb.core.eplb_utils import init_eplb_config
|
||||
from vllm_ascend.flash_common3_context import (get_flash_common3_context,
|
||||
set_flash_common3_context)
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import (select_experts,
|
||||
zero_experts_compute)
|
||||
from vllm_ascend.ops.fused_moe.moe_comm_method import (AllGatherCommImpl,
|
||||
FusedExpertsResult,
|
||||
setup_moe_comm_method)
|
||||
from vllm_ascend.flash_common3_context import get_flash_common3_context, set_flash_common3_context
|
||||
from vllm_ascend.ops.fused_moe.experts_selector import select_experts, zero_experts_compute
|
||||
from vllm_ascend.ops.fused_moe.moe_comm_method import AllGatherCommImpl, FusedExpertsResult, setup_moe_comm_method
|
||||
from vllm_ascend.ops.fused_moe.prepare_finalize import QuantType
|
||||
from vllm_ascend.utils import (AscendDeviceType, enable_sp,
|
||||
get_ascend_device_type, maybe_trans_nz,
|
||||
npu_stream_switch, shared_expert_dp_enabled,
|
||||
shared_experts_calculation_stream,
|
||||
vllm_version_is)
|
||||
from vllm_ascend.utils import (
|
||||
enable_sp,
|
||||
maybe_trans_nz,
|
||||
npu_stream_switch,
|
||||
shared_expert_dp_enabled,
|
||||
shared_experts_calculation_stream,
|
||||
vllm_version_is,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class FusedMoEResult:
|
||||
@@ -64,46 +61,43 @@ class FusedMoEEvents:
|
||||
|
||||
|
||||
class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
def __init__(self, moe: FusedMoEConfig = None):
|
||||
|
||||
super().__init__(moe=moe)
|
||||
self.dynamic_eplb = get_ascend_config().eplb_config.dynamic_eplb
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
super(UnquantizedFusedMoEMethod,
|
||||
self).process_weights_after_loading(layer)
|
||||
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
|
||||
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(1, 2).contiguous()
|
||||
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
||||
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
|
||||
1, 2).contiguous()
|
||||
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(1, 2).contiguous()
|
||||
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||
|
||||
layer.w13_weight.data = maybe_trans_nz(layer.w13_weight.data)
|
||||
layer.w2_weight.data = maybe_trans_nz(layer.w2_weight.data)
|
||||
|
||||
def apply(self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: Optional[int] = None,
|
||||
num_expert_group: Optional[int] = None,
|
||||
custom_routing_function: Optional[Callable] = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: Optional[torch.Tensor] = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
enable_force_load_balance: bool = False,
|
||||
log2phy: torch.Tensor = None,
|
||||
**kwargs) -> torch.Tensor:
|
||||
def apply(
|
||||
self,
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
use_grouped_topk: bool,
|
||||
top_k: int,
|
||||
router_logits: torch.Tensor,
|
||||
renormalize: bool,
|
||||
topk_group: int | None = None,
|
||||
num_expert_group: int | None = None,
|
||||
custom_routing_function: Callable | None = None,
|
||||
scoring_func: str = "softmax",
|
||||
routed_scaling_factor: float = 1.0,
|
||||
e_score_correction_bias: torch.Tensor | None = None,
|
||||
global_num_experts: int = -1,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
enable_force_load_balance: bool = False,
|
||||
log2phy: torch.Tensor = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
zero_expert_num = getattr(layer, "zero_expert_num", 0)
|
||||
zero_expert_type = getattr(layer, "zero_expert_type", None)
|
||||
topk_weights, topk_ids = select_experts(
|
||||
@@ -118,7 +112,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
scoring_func=scoring_func,
|
||||
routed_scaling_factor=routed_scaling_factor,
|
||||
e_score_correction_bias=e_score_correction_bias,
|
||||
global_num_experts=global_num_experts)
|
||||
global_num_experts=global_num_experts,
|
||||
)
|
||||
|
||||
if zero_expert_num > 0 and zero_expert_type is not None:
|
||||
topk_ids, topk_weights, zero_expert_result = zero_experts_compute(
|
||||
@@ -134,11 +129,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
# to avoid accumulating too much tokens on a single rank.
|
||||
# currently it is only activated when doing profile runs.
|
||||
if enable_force_load_balance:
|
||||
random_matrix = torch.rand(topk_ids.size(0),
|
||||
global_num_experts,
|
||||
device=topk_ids.device)
|
||||
topk_ids = torch.argsort(
|
||||
random_matrix, dim=1)[:, :topk_ids.size(1)].to(topk_ids.dtype)
|
||||
random_matrix = torch.rand(topk_ids.size(0), global_num_experts, device=topk_ids.device)
|
||||
topk_ids = torch.argsort(random_matrix, dim=1)[:, : topk_ids.size(1)].to(topk_ids.dtype)
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
final_hidden_states = moe_comm_method.fused_experts(
|
||||
@@ -151,7 +143,8 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
dynamic_eplb=self.dynamic_eplb,
|
||||
log2phy=log2phy,
|
||||
mc2_mask=kwargs.get("mc2_mask", None))
|
||||
mc2_mask=kwargs.get("mc2_mask"),
|
||||
)
|
||||
if zero_expert_num > 0 and zero_expert_type is not None:
|
||||
final_hidden_states += zero_expert_result
|
||||
return final_hidden_states
|
||||
@@ -159,7 +152,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
|
||||
|
||||
class AscendFusedMoE(FusedMoE):
|
||||
moe_counter = -1
|
||||
gate_stream: Optional[torch.npu.Stream] = None
|
||||
gate_stream: torch.npu.Stream | None = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
@@ -174,11 +167,9 @@ class AscendFusedMoE(FusedMoE):
|
||||
self.log2phy = None
|
||||
|
||||
if self.quant_config is None:
|
||||
self.quant_method = AscendUnquantizedFusedMoEMethod(
|
||||
self.moe_config)
|
||||
self.quant_method = AscendUnquantizedFusedMoEMethod(self.moe_config)
|
||||
else:
|
||||
self.quant_method = self.quant_config.get_quant_method(
|
||||
self, self.layer_name)
|
||||
self.quant_method = self.quant_config.get_quant_method(self, self.layer_name)
|
||||
|
||||
assert self.quant_method is not None
|
||||
|
||||
@@ -195,28 +186,32 @@ class AscendFusedMoE(FusedMoE):
|
||||
if self.custom_routing_function is None and self.e_score_correction_bias is not None:
|
||||
vllm_config = get_current_vllm_config()
|
||||
self.e_score_correction_bias.data = self.e_score_correction_bias.data.to(
|
||||
dtype=vllm_config.model_config.dtype)
|
||||
dtype=vllm_config.model_config.dtype
|
||||
)
|
||||
|
||||
# init moe
|
||||
eplb_config = ascend_config.eplb_config
|
||||
self.global_expert_map, self._expert_map, self.log2phy, self.global_redundant_expert_num = init_eplb_config(
|
||||
eplb_config, self.moe_instance_id, self.moe_config)
|
||||
eplb_config, self.moe_instance_id, self.moe_config
|
||||
)
|
||||
self.global_num_experts = num_experts + self.global_redundant_expert_num
|
||||
self.dynamic_eplb = eplb_config.dynamic_eplb and (self.log2phy
|
||||
is not None)
|
||||
self.local_num_experts = (torch.sum(
|
||||
self._expert_map != -1).item() if self._expert_map is not None else
|
||||
self.global_num_experts)
|
||||
self.dynamic_eplb = eplb_config.dynamic_eplb and (self.log2phy is not None)
|
||||
self.local_num_experts = (
|
||||
torch.sum(self._expert_map != -1).item() if self._expert_map is not None else self.global_num_experts
|
||||
)
|
||||
if self._expert_map is not None:
|
||||
logger.info_once(
|
||||
"[EP Rank %s/%s] Expert parallelism is enabled. Local/global"
|
||||
" number of experts: %s/%s. Experts local to global index map:"
|
||||
" %s.", self.ep_rank, self.ep_size, self.local_num_experts,
|
||||
" %s.",
|
||||
self.ep_rank,
|
||||
self.ep_size,
|
||||
self.local_num_experts,
|
||||
self.global_num_experts,
|
||||
get_compressed_expert_map(self._expert_map))
|
||||
get_compressed_expert_map(self._expert_map),
|
||||
)
|
||||
if self.dynamic_eplb:
|
||||
self.moe_load = torch.zeros(self.local_num_experts,
|
||||
dtype=torch.int64).npu()
|
||||
self.moe_load = torch.zeros(self.local_num_experts, dtype=torch.int64).npu()
|
||||
|
||||
self.moe_config.num_experts = self.global_num_experts
|
||||
self.moe_config.num_local_experts = self.local_num_experts
|
||||
@@ -225,14 +220,12 @@ class AscendFusedMoE(FusedMoE):
|
||||
moe_quant_params = {
|
||||
"num_experts": self.local_num_experts,
|
||||
"hidden_size": self.hidden_size,
|
||||
"intermediate_size_per_partition":
|
||||
self.intermediate_size_per_partition,
|
||||
"intermediate_size_per_partition": self.intermediate_size_per_partition,
|
||||
"params_dtype": self.params_dtype,
|
||||
"weight_loader": self.weight_loader,
|
||||
}
|
||||
# need full intermediate size pre-sharding for WNA16 act order
|
||||
if (self.quant_method.__class__.__name__
|
||||
in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
|
||||
if self.quant_method.__class__.__name__ in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod"):
|
||||
moe_quant_params["intermediate_size_full"] = intermediate_size
|
||||
self.quant_method.create_weights(layer=self, **moe_quant_params)
|
||||
|
||||
@@ -243,15 +236,14 @@ class AscendFusedMoE(FusedMoE):
|
||||
|
||||
def _get_quant_type(self) -> QuantType:
|
||||
quant_method = self.quant_method
|
||||
if not hasattr(quant_method,
|
||||
"quant_method") or quant_method.quant_method is None:
|
||||
if not hasattr(quant_method, "quant_method") or quant_method.quant_method is None:
|
||||
return QuantType.NONE
|
||||
|
||||
method = quant_method.quant_method
|
||||
|
||||
if hasattr(method, "quant_type"):
|
||||
from vllm_ascend.quantization.methods.base import \
|
||||
QuantType as SchemeQuantType
|
||||
from vllm_ascend.quantization.methods.base import QuantType as SchemeQuantType
|
||||
|
||||
scheme_quant_type = method.quant_type
|
||||
if scheme_quant_type == SchemeQuantType.W8A8:
|
||||
return QuantType.W8A8
|
||||
@@ -270,22 +262,18 @@ class AscendFusedMoE(FusedMoE):
|
||||
if self.moe_load is not None:
|
||||
self.moe_load.zero_()
|
||||
|
||||
def maybe_all_reduce_tensor_model_parallel(
|
||||
self, final_hidden_states: torch.Tensor):
|
||||
def maybe_all_reduce_tensor_model_parallel(self, final_hidden_states: torch.Tensor):
|
||||
"""NOTE(Yizhou): This is to override the parent class method. In `mc2commimpl`,
|
||||
and `alltoallcommimpl`, we do not need to all-reduce the final outputs since
|
||||
the outputs are already aggregated across tensor parallel ranks in the
|
||||
`finalize` function. In `allgathercommimpl`, we still need to all-reduce the
|
||||
outputs since each rank only has partial outputs.
|
||||
"""
|
||||
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(
|
||||
final_hidden_states)
|
||||
return torch.ops.vllm.maybe_all_reduce_tensor_model_parallel(final_hidden_states)
|
||||
|
||||
def forward_impl( # type: ignore[override]
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
return_with_event: bool = False) -> torch.Tensor | FusedMoEResult:
|
||||
self, hidden_states: torch.Tensor, router_logits: torch.Tensor, return_with_event: bool = False
|
||||
) -> torch.Tensor | FusedMoEResult:
|
||||
assert self.quant_method is not None
|
||||
|
||||
forward_context = get_forward_context()
|
||||
@@ -301,15 +289,16 @@ class AscendFusedMoE(FusedMoE):
|
||||
fc3_context = get_flash_common3_context()
|
||||
assert fc3_context is not None
|
||||
AscendFusedMoE.gate_stream.wait_stream(torch.npu.current_stream())
|
||||
with npu_stream_switch(AscendFusedMoE.gate_stream,
|
||||
enabled=self.multistream_overlap_gate):
|
||||
with npu_stream_switch(AscendFusedMoE.gate_stream, enabled=self.multistream_overlap_gate):
|
||||
# share_expert
|
||||
assert fc3_context.shared_experts is not None
|
||||
shared_out = fc3_context.shared_experts(hidden_states)
|
||||
# NOTE: This is exactly the opposite of `maybe_all_reduce_tensor_model_parallel`
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
|
||||
and not shared_expert_dp_enabled():
|
||||
if (
|
||||
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
|
||||
and not shared_expert_dp_enabled()
|
||||
):
|
||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||
set_flash_common3_context(shared_out=shared_out)
|
||||
|
||||
@@ -325,24 +314,22 @@ class AscendFusedMoE(FusedMoE):
|
||||
scoring_func=self.scoring_func,
|
||||
routed_scaling_factor=self.routed_scaling_factor,
|
||||
e_score_correction_bias=self.e_score_correction_bias,
|
||||
global_num_experts=self.global_num_experts)
|
||||
global_num_experts=self.global_num_experts,
|
||||
)
|
||||
|
||||
if isinstance(forward_context.moe_comm_method,
|
||||
AllGatherCommImpl):
|
||||
topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
topk_weights, True, True)
|
||||
topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
topk_ids, True, True)
|
||||
if isinstance(forward_context.moe_comm_method, AllGatherCommImpl):
|
||||
topk_weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_weights, True, True)
|
||||
topk_ids = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(topk_ids, True, True)
|
||||
|
||||
set_flash_common3_context(topk_weights=topk_weights,
|
||||
topk_ids=topk_ids)
|
||||
set_flash_common3_context(topk_weights=topk_weights, topk_ids=topk_ids)
|
||||
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = forward_context.moe_comm_method.prepare(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
replace_allreduce=forward_context.sp_enabled,
|
||||
enable_shared_expert_dp=self.enable_shared_expert_dp,
|
||||
quant_type=self.quant_type)
|
||||
quant_type=self.quant_type,
|
||||
)
|
||||
|
||||
# Make sure the default stream waits for the gate stream to finish.
|
||||
if self.multistream_overlap_gate:
|
||||
@@ -375,39 +362,45 @@ class AscendFusedMoE(FusedMoE):
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
log2phy=self.log2phy,
|
||||
global_redundant_expert_num=self.global_redundant_expert_num,
|
||||
mc2_mask=mc2_mask)
|
||||
mc2_mask=mc2_mask,
|
||||
)
|
||||
|
||||
if self.dynamic_eplb:
|
||||
expert_tokens = fused_experts_results.expert_tokens
|
||||
group_list_type = fused_experts_results.group_list_type
|
||||
assert expert_tokens is not None and group_list_type is not None, \
|
||||
assert expert_tokens is not None and group_list_type is not None, (
|
||||
"expert_tokens and group_list_type should not be None when dynamic_eplb is enabled."
|
||||
local_load = expert_tokens if group_list_type == 1 else \
|
||||
torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||
)
|
||||
local_load = (
|
||||
expert_tokens
|
||||
if group_list_type == 1
|
||||
else torch.cat([expert_tokens[:1], expert_tokens[1:] - expert_tokens[:-1]])
|
||||
)
|
||||
self.moe_load.add_(local_load)
|
||||
routed_out = forward_context.moe_comm_method.finalize(
|
||||
hidden_states=fused_experts_results.routed_out,
|
||||
reduce_results=self.reduce_results,
|
||||
context_metadata=context_metadata)
|
||||
context_metadata=context_metadata,
|
||||
)
|
||||
|
||||
if return_with_event:
|
||||
return FusedMoEResult(
|
||||
routed_out=routed_out,
|
||||
before_dispatch_evt=fused_experts_results.before_dispatch_evt,
|
||||
before_combine_evt=fused_experts_results.before_combine_evt)
|
||||
before_combine_evt=fused_experts_results.before_combine_evt,
|
||||
)
|
||||
else:
|
||||
# The vLLM FusedMoE forward_impl does not return events.
|
||||
return routed_out
|
||||
|
||||
|
||||
class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shared_experts: torch.nn.Module,
|
||||
gate: Optional[torch.nn.Module] = None,
|
||||
gate: torch.nn.Module | None = None,
|
||||
use_overlapped: bool = True,
|
||||
routed_input_transform: Optional[torch.nn.Module] = None,
|
||||
routed_input_transform: torch.nn.Module | None = None,
|
||||
**kwargs,
|
||||
):
|
||||
AscendFusedMoE.__init__(self, **kwargs)
|
||||
@@ -418,16 +411,12 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
self.use_overlapped = use_overlapped
|
||||
self.shared_expert_stream = None
|
||||
ascend_config = get_ascend_config()
|
||||
self.multistream_overlap_shared_expert = \
|
||||
ascend_config.multistream_overlap_shared_expert and \
|
||||
self._shared_experts is not None
|
||||
self.multistream_overlap_gate = \
|
||||
ascend_config.multistream_overlap_gate and \
|
||||
self._shared_experts is not None
|
||||
self.multistream_overlap_shared_expert = (
|
||||
ascend_config.multistream_overlap_shared_expert and self._shared_experts is not None
|
||||
)
|
||||
self.multistream_overlap_gate = ascend_config.multistream_overlap_gate and self._shared_experts is not None
|
||||
if enable_sp():
|
||||
logger.info_once(
|
||||
"Sequence parallelism is enabled, shared experts are replicated for best performance."
|
||||
)
|
||||
logger.info_once("Sequence parallelism is enabled, shared experts are replicated for best performance.")
|
||||
|
||||
self._gate = gate
|
||||
|
||||
@@ -447,20 +436,15 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
self.quant_method.process_weights_after_loading = wrapped_process_weights # type: ignore
|
||||
|
||||
def _shared_experts_part1(self, hidden_states: torch.Tensor):
|
||||
shared_gate_up, _ = self._shared_experts.gate_up_proj(
|
||||
hidden_states) # type: ignore
|
||||
shared_gate_up, _ = self._shared_experts.gate_up_proj(hidden_states) # type: ignore
|
||||
return shared_gate_up
|
||||
|
||||
def _shared_experts_part2(self, hidden_states: torch.Tensor,
|
||||
shared_gate_up: torch.Tensor):
|
||||
shared_act = self._shared_experts.act_fn(
|
||||
shared_gate_up) # type: ignore
|
||||
shared_out, _ = self._shared_experts.down_proj(
|
||||
shared_act) # type: ignore
|
||||
def _shared_experts_part2(self, hidden_states: torch.Tensor, shared_gate_up: torch.Tensor):
|
||||
shared_act = self._shared_experts.act_fn(shared_gate_up) # type: ignore
|
||||
shared_out, _ = self._shared_experts.down_proj(shared_act) # type: ignore
|
||||
|
||||
# Qwen3-Next specific gating mechanism
|
||||
if hasattr(self._shared_experts, "expert_gate") and \
|
||||
self._shared_experts.expert_gate is not None:
|
||||
if hasattr(self._shared_experts, "expert_gate") and self._shared_experts.expert_gate is not None:
|
||||
gate_out, _ = self._shared_experts.expert_gate(hidden_states) # type: ignore
|
||||
shared_out = F.sigmoid(gate_out) * shared_out
|
||||
return shared_out
|
||||
@@ -468,9 +452,9 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
def _validate_shared_expert_consistency(self):
|
||||
"""Validate that split shared expert computation matches integrated
|
||||
computation."""
|
||||
test_input = torch.rand(
|
||||
10, self.hidden_size, device='npu', dtype=self.moe_config.in_dtype
|
||||
) * 2 - 1 # Random input for testing, scoped to [-1, 1]
|
||||
test_input = (
|
||||
torch.rand(10, self.hidden_size, device="npu", dtype=self.moe_config.in_dtype) * 2 - 1
|
||||
) # Random input for testing, scoped to [-1, 1]
|
||||
|
||||
integrated_out = self._shared_experts(test_input)
|
||||
part1_out = self._shared_experts_part1(test_input)
|
||||
@@ -478,25 +462,19 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
|
||||
if not torch.allclose(integrated_out, split_out):
|
||||
diff = (integrated_out - split_out).abs()
|
||||
logger.error(
|
||||
"SharedFusedMoE shared experts split computation does not "
|
||||
"match the integrated computation.")
|
||||
logger.error("SharedFusedMoE shared experts split computation does not match the integrated computation.")
|
||||
logger.error(f"Max absolute difference: {diff.max().item()}")
|
||||
logger.error("Integrated output - sum: %s, norm: %s",
|
||||
integrated_out.sum().item(),
|
||||
integrated_out.norm().item())
|
||||
logger.error("Split output - sum: %s, norm: %s",
|
||||
split_out.sum().item(),
|
||||
split_out.norm().item())
|
||||
logger.error(
|
||||
"Integrated output - sum: %s, norm: %s", integrated_out.sum().item(), integrated_out.norm().item()
|
||||
)
|
||||
logger.error("Split output - sum: %s, norm: %s", split_out.sum().item(), split_out.norm().item())
|
||||
raise ValueError(
|
||||
"SharedFusedMoE shared experts split computation does not "
|
||||
"match the integrated computation.")
|
||||
logger.info_once(
|
||||
"SharedFusedMoE shared experts split computation matches the "
|
||||
"integrated computation.")
|
||||
"SharedFusedMoE shared experts split computation does not match the integrated computation."
|
||||
)
|
||||
logger.info_once("SharedFusedMoE shared experts split computation matches the integrated computation.")
|
||||
|
||||
@property
|
||||
def gate(self) -> Optional[torch.nn.Module]:
|
||||
def gate(self) -> torch.nn.Module | None:
|
||||
return self._gate if self.use_overlapped else None
|
||||
|
||||
@property
|
||||
@@ -530,8 +508,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
)
|
||||
return shared_out, fused_out
|
||||
|
||||
def _forward_shared_experts(self, hidden_states: torch.Tensor,
|
||||
fused_moe_evts: FusedMoEEvents):
|
||||
def _forward_shared_experts(self, hidden_states: torch.Tensor, fused_moe_evts: FusedMoEEvents):
|
||||
if self._shared_experts is None:
|
||||
return None
|
||||
|
||||
@@ -539,11 +516,9 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
if evt is not None:
|
||||
torch.npu.current_stream().wait_event(evt)
|
||||
|
||||
with npu_stream_switch(shared_experts_calculation_stream(),
|
||||
enabled=self.multistream_overlap_shared_expert):
|
||||
with npu_stream_switch(shared_experts_calculation_stream(), enabled=self.multistream_overlap_shared_expert):
|
||||
# Ensure the shared experts wait for hidden_states to be ready.
|
||||
torch.npu.current_stream().wait_event(
|
||||
fused_moe_evts.before_routed_experts)
|
||||
torch.npu.current_stream().wait_event(fused_moe_evts.before_routed_experts)
|
||||
# Execute the gate projection and activation concurrently with the
|
||||
# dispatch communication.
|
||||
maybe_wait_event(fused_moe_evts.before_dispatch)
|
||||
@@ -556,20 +531,22 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
# Make sure the default stream waits for the shared experts stream to
|
||||
# finish.
|
||||
if self.multistream_overlap_shared_expert:
|
||||
torch.npu.current_stream().wait_stream(
|
||||
shared_experts_calculation_stream())
|
||||
torch.npu.current_stream().wait_stream(shared_experts_calculation_stream())
|
||||
|
||||
# NOTE: This is exactly the opposite of
|
||||
# `maybe_all_reduce_tensor_model_parallel`
|
||||
forward_context = get_forward_context()
|
||||
moe_comm_type = forward_context.moe_comm_type
|
||||
if moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2} \
|
||||
and not shared_expert_dp_enabled():
|
||||
if (
|
||||
moe_comm_type in {MoECommType.ALLTOALL, MoECommType.MC2, MoECommType.FUSED_MC2}
|
||||
and not shared_expert_dp_enabled()
|
||||
):
|
||||
shared_out = tensor_model_parallel_all_reduce(shared_out)
|
||||
return shared_out
|
||||
|
||||
def forward_impl( # type: ignore[override]
|
||||
self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
|
||||
self, hidden_states: torch.Tensor, router_logits: torch.Tensor
|
||||
):
|
||||
if self.multistream_overlap_gate:
|
||||
set_flash_common3_context(shared_experts=self._shared_experts)
|
||||
|
||||
@@ -596,6 +573,7 @@ class AscendSharedFusedMoE(SharedFusedMoE, AscendFusedMoE):
|
||||
before_routed_experts=before_routed_experts,
|
||||
before_dispatch=fused_moe_results.before_dispatch_evt,
|
||||
before_combine=fused_moe_results.before_combine_evt,
|
||||
))
|
||||
),
|
||||
)
|
||||
|
||||
return shared_out, routed_out
|
||||
|
||||
@@ -17,7 +17,6 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Optional
|
||||
|
||||
import torch
|
||||
from vllm.forward_context import get_forward_context
|
||||
@@ -27,18 +26,24 @@ import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.ops.fused_moe.moe_mlp import unified_apply_mlp
|
||||
from vllm_ascend.ops.fused_moe.prepare_finalize import (
|
||||
PrepareAndFinalize, PrepareAndFinalizeWithAll2All,
|
||||
PrepareAndFinalizeWithAllGather, PrepareAndFinalizeWithMC2, QuantType)
|
||||
PrepareAndFinalize,
|
||||
PrepareAndFinalizeWithAll2All,
|
||||
PrepareAndFinalizeWithAllGather,
|
||||
PrepareAndFinalizeWithMC2,
|
||||
QuantType,
|
||||
)
|
||||
from vllm_ascend.ops.fused_moe.token_dispatcher import (
|
||||
MoETokenDispatcher, TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather, TokenDispatcherWithMC2)
|
||||
MoETokenDispatcher,
|
||||
TokenDispatcherWithAll2AllV,
|
||||
TokenDispatcherWithAllGather,
|
||||
TokenDispatcherWithMC2,
|
||||
)
|
||||
|
||||
_MoECommMethods: Dict[Optional[MoECommType], MoECommMethod] = {}
|
||||
_MoECommMethods: dict[MoECommType | None, MoECommMethod] = {}
|
||||
|
||||
|
||||
def get_moe_comm_method(
|
||||
moe_comm_type: Optional[MoECommType]) -> Optional[MoECommMethod]:
|
||||
return _MoECommMethods.get(moe_comm_type, None)
|
||||
def get_moe_comm_method(moe_comm_type: MoECommType | None) -> MoECommMethod | None:
|
||||
return _MoECommMethods.get(moe_comm_type)
|
||||
|
||||
|
||||
def setup_moe_comm_method(moe_config):
|
||||
@@ -50,6 +55,7 @@ def setup_moe_comm_method(moe_config):
|
||||
|
||||
def set_gmmswigluquant_method():
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
return ascend_config.ascend_fusion_config.fusion_ops_gmmswigluquant
|
||||
|
||||
@@ -84,51 +90,46 @@ class MoECommMethod(ABC):
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type: QuantType = QuantType.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
hidden_states, router_logits, mc2_mask, context_metadata = self.prepare_finalize.prepare(
|
||||
hidden_states, router_logits, enable_shared_expert_dp,
|
||||
replace_allreduce, quant_type)
|
||||
hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce, quant_type
|
||||
)
|
||||
return hidden_states, router_logits, mc2_mask, context_metadata
|
||||
|
||||
def finalize(self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
hidden_states = self.prepare_finalize.finalize(hidden_states,
|
||||
reduce_results,
|
||||
context_metadata)
|
||||
def finalize(
|
||||
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
|
||||
) -> torch.Tensor:
|
||||
hidden_states = self.prepare_finalize.finalize(hidden_states, reduce_results, context_metadata)
|
||||
return hidden_states
|
||||
|
||||
def fused_experts(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor | list[torch.Tensor],
|
||||
w2: torch.Tensor | list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[list[torch.Tensor]] = None,
|
||||
w2_scale: Optional[list[torch.Tensor]] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: Optional[torch.Tensor] = None,
|
||||
w2_offset: Optional[torch.Tensor] = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None,
|
||||
pertoken_scale: Optional[torch.Tensor] = None):
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor | list[torch.Tensor],
|
||||
w2: torch.Tensor | list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: list[torch.Tensor] | None = None,
|
||||
w2_scale: list[torch.Tensor] | None = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: torch.Tensor | None = None,
|
||||
w2_offset: torch.Tensor | None = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
):
|
||||
# Check constraints
|
||||
assert hidden_states.dtype in [
|
||||
torch.float32, torch.float16, torch.bfloat16, torch.int8
|
||||
]
|
||||
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16, torch.int8]
|
||||
|
||||
moe_comm_method = get_forward_context().moe_comm_method
|
||||
assert moe_comm_method is not None, "Missing communication context"
|
||||
@@ -143,13 +144,13 @@ class MoECommMethod(ABC):
|
||||
topk_weights=topk_weights,
|
||||
topk_ids=topk_ids,
|
||||
expert_map=expert_map,
|
||||
global_redundant_expert_num=self.moe_config.
|
||||
global_redundant_expert_num,
|
||||
global_redundant_expert_num=self.moe_config.global_redundant_expert_num,
|
||||
mc2_mask=mc2_mask,
|
||||
apply_router_weight_on_input=apply_router_weight_on_input,
|
||||
with_quant=use_int8_w8a8 or use_int4_w4a8,
|
||||
dynamic_eplb=dynamic_eplb,
|
||||
pertoken_scale=pertoken_scale)
|
||||
pertoken_scale=pertoken_scale,
|
||||
)
|
||||
|
||||
mlp_output = unified_apply_mlp(
|
||||
hidden_states=dispatch_results.hidden_states,
|
||||
@@ -168,29 +169,29 @@ class MoECommMethod(ABC):
|
||||
with_quant=use_int8_w8a8 or use_int4_w4a8 or use_int4_w4a16,
|
||||
fusion=use_int8_w8a8 and self.use_fusion_ops,
|
||||
need_trans=need_trans,
|
||||
dynamic_eplb=dynamic_eplb)
|
||||
dynamic_eplb=dynamic_eplb,
|
||||
)
|
||||
|
||||
before_combine_evt = torch.npu.current_stream().record_event()
|
||||
combine_results = self.token_dispatcher.token_combine(
|
||||
hidden_states=mlp_output,
|
||||
context_metadata=dispatch_results.context_metadata)
|
||||
hidden_states=mlp_output, context_metadata=dispatch_results.context_metadata
|
||||
)
|
||||
|
||||
return FusedExpertsResult(
|
||||
routed_out=combine_results.routed_out,
|
||||
before_dispatch_evt=before_dispatch_evt,
|
||||
before_combine_evt=before_combine_evt,
|
||||
group_list_type=dispatch_results.group_list_type,
|
||||
expert_tokens=dispatch_results.group_list)
|
||||
expert_tokens=dispatch_results.group_list,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def _get_token_dispatcher(self) -> MoETokenDispatcher:
|
||||
raise NotImplementedError(
|
||||
"_get_token_dispatcher function not implemented.")
|
||||
raise NotImplementedError("_get_token_dispatcher function not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def _get_prepare_finalize(self) -> PrepareAndFinalize:
|
||||
raise NotImplementedError(
|
||||
"_get_prepare_finalize function not implemented.")
|
||||
raise NotImplementedError("_get_prepare_finalize function not implemented.")
|
||||
|
||||
|
||||
class AllGatherCommImpl(MoECommMethod):
|
||||
@@ -216,7 +217,8 @@ class AllGatherCommImpl(MoECommMethod):
|
||||
return TokenDispatcherWithAllGather(
|
||||
top_k=self.moe_config.experts_per_token,
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
num_local_experts=self.moe_config.num_local_experts,
|
||||
)
|
||||
|
||||
def _get_prepare_finalize(self):
|
||||
return PrepareAndFinalizeWithAllGather(self.moe_config)
|
||||
@@ -227,7 +229,7 @@ class MC2CommImpl(MoECommMethod):
|
||||
1. `enable_expert_parallel=True`.
|
||||
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
|
||||
3. `enable_expert_parallel=False` is not supported.
|
||||
|
||||
|
||||
This implementation uses the MC2 communication method, which is optimized for
|
||||
Communication and Computation parallelism on Ascend devices.
|
||||
"""
|
||||
@@ -253,7 +255,8 @@ class AlltoAllCommImpl(MoECommMethod):
|
||||
return TokenDispatcherWithAll2AllV(
|
||||
top_k=self.moe_config.experts_per_token,
|
||||
num_experts=self.moe_config.num_experts,
|
||||
num_local_experts=self.moe_config.num_local_experts)
|
||||
num_local_experts=self.moe_config.num_local_experts,
|
||||
)
|
||||
|
||||
def _get_prepare_finalize(self):
|
||||
return PrepareAndFinalizeWithAll2All(self.moe_config)
|
||||
@@ -264,7 +267,7 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
1. `enable_expert_parallel=True`.
|
||||
2. `npu_moe_distribute_dispatch` and `npu_moe_distribute_combine` are available.
|
||||
3. `enable_expert_parallel=False` is not supported.
|
||||
|
||||
|
||||
This implementation uses the MC2 communication method, which is optimized for
|
||||
Communication and Computation parallelism on Ascend devices.
|
||||
"""
|
||||
@@ -276,36 +279,36 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
return PrepareAndFinalizeWithMC2(self.moe_config)
|
||||
|
||||
def fused_experts(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor | list[torch.Tensor],
|
||||
w2: torch.Tensor | list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
w1_scale: Optional[list[torch.Tensor]] = None,
|
||||
w2_scale: Optional[list[torch.Tensor]] = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: Optional[torch.Tensor] = None,
|
||||
w2_offset: Optional[torch.Tensor] = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None,
|
||||
pertoken_scale: Optional[torch.Tensor] = None):
|
||||
assert not (
|
||||
w1_scale is None or w2_scale is None
|
||||
), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor | list[torch.Tensor],
|
||||
w2: torch.Tensor | list[torch.Tensor],
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
activation: str = "silu",
|
||||
apply_router_weight_on_input: bool = False,
|
||||
use_int8_w8a8: bool = False,
|
||||
use_int4_w4a8: bool = False,
|
||||
use_int4_w4a16: bool = False,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
w1_scale: list[torch.Tensor] | None = None,
|
||||
w2_scale: list[torch.Tensor] | None = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: torch.Tensor | None = None,
|
||||
w2_offset: torch.Tensor | None = None,
|
||||
# For load balance
|
||||
log2phy: torch.Tensor = None,
|
||||
need_trans: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
mc2_mask: torch.Tensor = None,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
):
|
||||
assert not (w1_scale is None or w2_scale is None), "w1_scale and w2_scale cannot be None for FusedMC2CommImpl."
|
||||
|
||||
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), \
|
||||
assert isinstance(self.token_dispatcher, TokenDispatcherWithMC2), (
|
||||
"token_dispatcher must be an instance of TokenDispatcherWithMC2."
|
||||
)
|
||||
|
||||
# Apply log2phy if needed
|
||||
if log2phy is not None:
|
||||
@@ -346,10 +349,8 @@ class FusedMC2CommImpl(MoECommMethod):
|
||||
ep_rank_size=self.token_dispatcher.ep_world_size,
|
||||
ep_rank_id=self.token_dispatcher.ep_rank_id,
|
||||
moe_expert_num=self.moe_config.num_experts,
|
||||
global_bs=self.token_dispatcher.global_bs)
|
||||
global_bs=self.token_dispatcher.global_bs,
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
|
||||
return FusedExpertsResult(routed_out=out,
|
||||
group_list_type=group_list_type,
|
||||
expert_tokens=expert_tokens)
|
||||
raise ValueError(f"Wrong value of {envs_ascend.VLLM_ASCEND_ENABLE_FUSED_MC2=}")
|
||||
return FusedExpertsResult(routed_out=out, group_list_type=group_list_type, expert_tokens=expert_tokens)
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
@@ -23,24 +22,22 @@ from vllm.forward_context import get_forward_context
|
||||
from vllm.triton_utils import HAS_TRITON
|
||||
|
||||
from vllm_ascend.ascend_forward_context import MoECommType
|
||||
from vllm_ascend.utils import (AscendDeviceType, dispose_tensor,
|
||||
enable_custom_op, get_ascend_device_type,
|
||||
get_weight_prefetch_method)
|
||||
from vllm_ascend.utils import (
|
||||
dispose_tensor,
|
||||
enable_custom_op,
|
||||
get_weight_prefetch_method,
|
||||
)
|
||||
|
||||
|
||||
def _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
|
||||
return fusion and dynamic_eplb and enable_custom_op()
|
||||
|
||||
|
||||
def cumsum_group_list(group_list: torch.Tensor,
|
||||
src_list_type: int,
|
||||
dst_list_type: int,
|
||||
active_num: int = 0,
|
||||
expert_num: int = 0) -> torch.Tensor:
|
||||
def cumsum_group_list(
|
||||
group_list: torch.Tensor, src_list_type: int, dst_list_type: int, active_num: int = 0, expert_num: int = 0
|
||||
) -> torch.Tensor:
|
||||
if src_list_type not in [0, 1, 2]:
|
||||
raise ValueError(
|
||||
f"group_list_type should be in [0, 1, 2], but received {src_list_type}"
|
||||
)
|
||||
raise ValueError(f"group_list_type should be in [0, 1, 2], but received {src_list_type}")
|
||||
|
||||
if src_list_type == dst_list_type:
|
||||
return group_list
|
||||
@@ -53,10 +50,9 @@ def cumsum_group_list(group_list: torch.Tensor,
|
||||
if src_list_type == 2 and dst_list_type == 0:
|
||||
experts = pad(group_list[:, 0], (1, 0))
|
||||
tokens = pad(group_list[:, 1].cumsum(dim=0), (1, 0))
|
||||
cumsum_group_list = torch.full(size=(expert_num, ),
|
||||
fill_value=active_num,
|
||||
dtype=group_list.dtype,
|
||||
device=group_list.device)
|
||||
cumsum_group_list = torch.full(
|
||||
size=(expert_num,), fill_value=active_num, dtype=group_list.dtype, device=group_list.device
|
||||
)
|
||||
|
||||
for i, (start, end) in enumerate(zip(experts[:-1], experts[1:])):
|
||||
if end > start:
|
||||
@@ -65,30 +61,32 @@ def cumsum_group_list(group_list: torch.Tensor,
|
||||
return cumsum_group_list
|
||||
raise NotImplementedError(
|
||||
f"Conversion from src_list_type={src_list_type} to dst_list_type={dst_list_type} is not implemented yet. "
|
||||
"This feature is under development.")
|
||||
"This feature is under development."
|
||||
)
|
||||
|
||||
|
||||
def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: list[torch.Tensor],
|
||||
w1_scale: list[torch.Tensor],
|
||||
w2: list[torch.Tensor],
|
||||
w2_scale: list[torch.Tensor],
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: Optional[torch.Tensor] = None,
|
||||
w2_offset: Optional[torch.Tensor] = None,
|
||||
fusion: bool = False,
|
||||
dynamic_eplb: bool = False) -> torch.Tensor:
|
||||
def quant_apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: list[torch.Tensor],
|
||||
w1_scale: list[torch.Tensor],
|
||||
w2: list[torch.Tensor],
|
||||
w2_scale: list[torch.Tensor],
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: torch.Tensor | None = None,
|
||||
w2_offset: torch.Tensor | None = None,
|
||||
fusion: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if w1_offset is not None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
quantized_hidden_states = None
|
||||
elif dynamic_scale is None:
|
||||
unquantized_hidden_states = hidden_states
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||
# Dispose the original unquantized hidden states
|
||||
# to save npu memory because they're no longer used.
|
||||
dispose_tensor(unquantized_hidden_states)
|
||||
@@ -103,22 +101,18 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
|
||||
weight_prefetch_method = get_weight_prefetch_method()
|
||||
if weight_prefetch_method:
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(
|
||||
hidden_states)
|
||||
weight_prefetch_method.maybe_prefetch_moe_weight_postprocess(hidden_states)
|
||||
is_mc2 = get_forward_context().moe_comm_type == MoECommType.MC2
|
||||
if w1_scale_bias is None and w1_offset is None and is_mc2:
|
||||
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = (
|
||||
torch.ops._C_ascend.
|
||||
grouped_matmul_swiglu_quant_weight_nz_tensor_list(
|
||||
x=hidden_states,
|
||||
weight=w1,
|
||||
weight_scale=w1_scale,
|
||||
x_scale=pertoken_scale,
|
||||
group_list=cumsum_group_list(group_list, group_list_type,
|
||||
0),
|
||||
))
|
||||
hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
|
||||
x=hidden_states,
|
||||
weight=w1,
|
||||
weight_scale=w1_scale,
|
||||
x_scale=pertoken_scale,
|
||||
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||
)
|
||||
elif fusion and not dynamic_eplb:
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||
@@ -126,7 +120,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
weight=w1[0],
|
||||
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||
weight_scale=w1_scale[0],
|
||||
x_scale=pertoken_scale)
|
||||
x_scale=pertoken_scale,
|
||||
)
|
||||
if quantized_hidden_states is not None:
|
||||
dispose_tensor(quantized_hidden_states)
|
||||
else:
|
||||
@@ -140,7 +135,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=torch.int32)[0]
|
||||
output_dtype=torch.int32,
|
||||
)[0]
|
||||
if quantized_hidden_states is not None:
|
||||
dispose_tensor(quantized_hidden_states)
|
||||
# act_fn: swiglu
|
||||
@@ -165,7 +161,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale[0].dtype)[0]
|
||||
output_dtype=w2_scale[0].dtype,
|
||||
)[0]
|
||||
elif w1_offset is not None:
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
@@ -177,7 +174,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
output_dtype=_output_dtype,
|
||||
)[0]
|
||||
dispose_tensor(unquantized_hidden_states)
|
||||
# act_fn: swiglu
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
@@ -191,13 +189,12 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
output_dtype=_output_dtype,
|
||||
)[0]
|
||||
else:
|
||||
if w1_scale_bias is not None:
|
||||
if group_list_type == 0:
|
||||
group_list = torch.cat(
|
||||
[group_list[:1],
|
||||
torch.diff(group_list, dim=0)])
|
||||
group_list = torch.cat([group_list[:1], torch.diff(group_list, dim=0)])
|
||||
group_list_type = 1
|
||||
bias1 = [w1_scale_bias] if not fusion else w1_scale_bias
|
||||
bias2 = [w2_scale_bias]
|
||||
@@ -206,17 +203,14 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
|
||||
if _custom_gmm_swiglu_enabled(fusion, dynamic_eplb):
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = (
|
||||
torch.ops._C_ascend.
|
||||
grouped_matmul_swiglu_quant_weight_nz_tensor_list(
|
||||
x=hidden_states,
|
||||
weight=w1,
|
||||
weight_scale=w1_scale,
|
||||
x_scale=pertoken_scale,
|
||||
group_list=cumsum_group_list(group_list, group_list_type,
|
||||
0),
|
||||
bias=bias1,
|
||||
))
|
||||
hidden_states, swiglu_out_scale, _ = torch.ops._C_ascend.grouped_matmul_swiglu_quant_weight_nz_tensor_list(
|
||||
x=hidden_states,
|
||||
weight=w1,
|
||||
weight_scale=w1_scale,
|
||||
x_scale=pertoken_scale,
|
||||
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||
bias=bias1,
|
||||
)
|
||||
elif fusion and not dynamic_eplb:
|
||||
# gmm1: gate_up_proj & act_fn: swiglu
|
||||
hidden_states, swiglu_out_scale, _ = torch_npu.npu_grouped_matmul_swiglu_quant(
|
||||
@@ -225,7 +219,8 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
bias=bias1,
|
||||
group_list=cumsum_group_list(group_list, group_list_type, 0),
|
||||
weight_scale=w1_scale[0],
|
||||
x_scale=pertoken_scale)
|
||||
x_scale=pertoken_scale,
|
||||
)
|
||||
if quantized_hidden_states is not None:
|
||||
dispose_tensor(quantized_hidden_states)
|
||||
else:
|
||||
@@ -241,21 +236,20 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
output_dtype=_output_dtype,
|
||||
)[0]
|
||||
if quantized_hidden_states is not None:
|
||||
dispose_tensor(quantized_hidden_states)
|
||||
# act_fn: swiglu
|
||||
if HAS_TRITON:
|
||||
from vllm_ascend.ops.triton.activation.swiglu_quant import \
|
||||
swiglu_quant
|
||||
from vllm_ascend.ops.triton.activation.swiglu_quant import swiglu_quant
|
||||
|
||||
hidden_states, swiglu_out_scale = swiglu_quant(
|
||||
hidden_states,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type)
|
||||
hidden_states, group_list=group_list, group_list_type=group_list_type
|
||||
)
|
||||
else:
|
||||
hidden_states = torch_npu.npu_swiglu(hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
hidden_states, swiglu_out_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||
# gmm2: down_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
@@ -267,18 +261,20 @@ def quant_apply_mlp(hidden_states: torch.Tensor,
|
||||
group_list_type=group_list_type,
|
||||
group_type=0,
|
||||
group_list=group_list,
|
||||
output_dtype=_output_dtype)[0]
|
||||
output_dtype=_output_dtype,
|
||||
)[0]
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unquant_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
topk_scales: Optional[torch.Tensor] = None,
|
||||
need_trans: bool = True) -> torch.Tensor:
|
||||
|
||||
def unquant_apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
group_list_type: int = 1,
|
||||
topk_scales: torch.Tensor | None = None,
|
||||
need_trans: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if need_trans:
|
||||
w1 = w1.transpose(1, 2)
|
||||
w2 = w2.transpose(1, 2)
|
||||
@@ -307,44 +303,50 @@ def unquant_apply_mlp(hidden_states: torch.Tensor,
|
||||
return hidden_states
|
||||
|
||||
|
||||
def unified_apply_mlp(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor | list[torch.Tensor],
|
||||
w2: torch.Tensor | list[torch.Tensor],
|
||||
group_list: torch.Tensor,
|
||||
w1_scale: Optional[list[torch.Tensor]] = None,
|
||||
w2_scale: Optional[list[torch.Tensor]] = None,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: Optional[torch.Tensor] = None,
|
||||
w2_offset: Optional[torch.Tensor] = None,
|
||||
topk_scales: Optional[torch.Tensor] = None,
|
||||
with_quant: bool = False,
|
||||
fusion: bool = False,
|
||||
need_trans: bool = True,
|
||||
dynamic_eplb: bool = False) -> torch.Tensor:
|
||||
def unified_apply_mlp(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor | list[torch.Tensor],
|
||||
w2: torch.Tensor | list[torch.Tensor],
|
||||
group_list: torch.Tensor,
|
||||
w1_scale: list[torch.Tensor] | None = None,
|
||||
w2_scale: list[torch.Tensor] | None = None,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
w1_scale_bias: torch.Tensor = None,
|
||||
w2_scale_bias: torch.Tensor = None,
|
||||
w1_offset: torch.Tensor | None = None,
|
||||
w2_offset: torch.Tensor | None = None,
|
||||
topk_scales: torch.Tensor | None = None,
|
||||
with_quant: bool = False,
|
||||
fusion: bool = False,
|
||||
need_trans: bool = True,
|
||||
dynamic_eplb: bool = False,
|
||||
) -> torch.Tensor:
|
||||
if with_quant:
|
||||
assert w1_scale is not None and w2_scale is not None
|
||||
return quant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
w1_offset=w1_offset,
|
||||
w2_offset=w2_offset,
|
||||
fusion=fusion,
|
||||
dynamic_eplb=dynamic_eplb)
|
||||
return quant_apply_mlp(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w1_scale=w1_scale,
|
||||
w2=w2,
|
||||
w2_scale=w2_scale,
|
||||
group_list=group_list,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list_type=group_list_type,
|
||||
w1_scale_bias=w1_scale_bias,
|
||||
w2_scale_bias=w2_scale_bias,
|
||||
w1_offset=w1_offset,
|
||||
w2_offset=w2_offset,
|
||||
fusion=fusion,
|
||||
dynamic_eplb=dynamic_eplb,
|
||||
)
|
||||
else:
|
||||
return unquant_apply_mlp(hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type,
|
||||
topk_scales=topk_scales,
|
||||
need_trans=need_trans)
|
||||
return unquant_apply_mlp(
|
||||
hidden_states=hidden_states,
|
||||
w1=w1,
|
||||
w2=w2,
|
||||
group_list=group_list,
|
||||
group_list_type=group_list_type,
|
||||
topk_scales=topk_scales,
|
||||
need_trans=need_trans,
|
||||
)
|
||||
|
||||
@@ -16,22 +16,23 @@
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm.distributed.parallel_state import (
|
||||
get_dp_group, get_pcp_group, get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size)
|
||||
get_dp_group,
|
||||
get_pcp_group,
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.fused_moe import FusedMoEConfig
|
||||
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.utils import fc3_all_gather_and_maybe_unpad_impl
|
||||
from vllm_ascend.utils import (enable_sp, npu_stream_switch,
|
||||
prefill_context_parallel_enable)
|
||||
from vllm_ascend.utils import enable_sp, npu_stream_switch, prefill_context_parallel_enable
|
||||
|
||||
|
||||
class QuantType(Enum):
|
||||
@@ -51,7 +52,8 @@ class PrepareAndFinalize(ABC):
|
||||
moe_config (FusedMoEConfig): Configuration object containing TP/DP/EP group info,
|
||||
sizes, ranks, and communication settings.
|
||||
"""
|
||||
quant_stream: Optional[torch.npu.Stream] = None
|
||||
|
||||
quant_stream: torch.npu.Stream | None = None
|
||||
|
||||
def __init__(self, moe_config: FusedMoEConfig):
|
||||
self.moe_config = moe_config
|
||||
@@ -67,9 +69,8 @@ class PrepareAndFinalize(ABC):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type: QuantType = QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
quant_type: QuantType = QuantType.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Prepare tensors before MoE computation. May involve:
|
||||
- Padding to align communication boundaries
|
||||
@@ -92,10 +93,9 @@ class PrepareAndFinalize(ABC):
|
||||
"""
|
||||
raise NotImplementedError("Prepare not implemented.")
|
||||
|
||||
def finalize(self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
def finalize(
|
||||
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Finalize MoE output. May involve:
|
||||
- Gathering sliced tensors across TP ranks
|
||||
@@ -135,9 +135,8 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
quant_type=QuantType.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Pad hidden_states and router_logits to next multiple of TP size.
|
||||
@@ -158,33 +157,24 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
pad_size = self.tp_size - self.num_tokens # Pad to TP size (cyclic)
|
||||
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
|
||||
padded_hidden_states_shape = hidden_states.shape
|
||||
|
||||
if self.tp_size > 1:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits, self.tp_size, dim=0)
|
||||
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
|
||||
context_metadata = {
|
||||
"padded_hidden_states_shape": padded_hidden_states_shape
|
||||
}
|
||||
context_metadata = {"padded_hidden_states_shape": padded_hidden_states_shape}
|
||||
|
||||
return hidden_states, router_logits, None, context_metadata
|
||||
|
||||
def finalize(self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
def finalize(
|
||||
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If TP > 1, all-gather slices to reconstruct full tensor.
|
||||
@@ -201,20 +191,16 @@ class PrepareAndFinalizeWithAll2All(PrepareAndFinalize):
|
||||
# may share memory with original hidden_states. Since shared
|
||||
# experts may use the original tensor, reusing it would cause
|
||||
# in-place modification during all_gather, corrupting the data.
|
||||
padded_hidden_states_shape = context_metadata[
|
||||
"padded_hidden_states_shape"]
|
||||
padded_hidden_states_shape = context_metadata["padded_hidden_states_shape"]
|
||||
gathered_hidden_states = torch.empty(
|
||||
padded_hidden_states_shape,
|
||||
device=hidden_states.device,
|
||||
dtype=hidden_states.dtype)
|
||||
split_hidden_states = torch.tensor_split(
|
||||
gathered_hidden_states, self.tp_size, dim=0)
|
||||
dist.all_gather(list(split_hidden_states), hidden_states,
|
||||
self.moe_config.tp_group.device_group)
|
||||
padded_hidden_states_shape, device=hidden_states.device, dtype=hidden_states.dtype
|
||||
)
|
||||
split_hidden_states = torch.tensor_split(gathered_hidden_states, self.tp_size, dim=0)
|
||||
dist.all_gather(list(split_hidden_states), hidden_states, self.moe_config.tp_group.device_group)
|
||||
hidden_states = gathered_hidden_states
|
||||
|
||||
if self.num_tokens < hidden_states.shape[0]:
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
hidden_states = hidden_states[: self.num_tokens]
|
||||
|
||||
return hidden_states
|
||||
|
||||
@@ -246,9 +232,8 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
quant_type=QuantType.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch `mc2_mask` and target padding length from forward context.
|
||||
@@ -278,20 +263,14 @@ class PrepareAndFinalizeWithMC2(PrepareAndFinalizeWithAll2All):
|
||||
|
||||
# Pad if necessary (unless shared expert DP is enabled)
|
||||
if pad_size > 0 and not self.enable_shared_expert_dp:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
|
||||
padded_hidden_states_shape = hidden_states.shape
|
||||
|
||||
# Slice across TP ranks
|
||||
if self.tp_size > 1 and not self.enable_shared_expert_dp:
|
||||
split_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
split_hidden_states = torch.tensor_split(hidden_states, self.tp_size, dim=0)
|
||||
split_router_logits = torch.tensor_split(router_logits, self.tp_size, dim=0)
|
||||
hidden_states = split_hidden_states[self.tp_rank]
|
||||
router_logits = split_router_logits[self.tp_rank]
|
||||
|
||||
@@ -330,9 +309,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
quant_type=QuantType.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Preparation steps:
|
||||
AllGather hidden_states and router_logits to form global tensors.
|
||||
@@ -341,46 +319,31 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
Tuple of (global_hidden_states, global_router_logits, None)
|
||||
"""
|
||||
if enable_sp():
|
||||
return self._prepare_with_ep_group(hidden_states, router_logits,
|
||||
quant_type)
|
||||
return self._prepare_with_ep_group(hidden_states, router_logits, quant_type)
|
||||
|
||||
return self._prepare_with_dp_group(hidden_states, router_logits,
|
||||
enable_shared_expert_dp,
|
||||
replace_allreduce)
|
||||
return self._prepare_with_dp_group(hidden_states, router_logits, enable_shared_expert_dp, replace_allreduce)
|
||||
|
||||
def _prepare_with_ep_group(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
router_logits: torch.Tensor,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
self, hidden_states: torch.Tensor, router_logits: torch.Tensor, quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
pertoken_scale = None
|
||||
if quant_type == QuantType.W8A8:
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(
|
||||
hidden_states)
|
||||
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
|
||||
|
||||
if self.multistream_overlap_gate:
|
||||
assert PrepareAndFinalize.quant_stream is not None
|
||||
PrepareAndFinalize.quant_stream.wait_stream(
|
||||
torch.npu.current_stream())
|
||||
with npu_stream_switch(PrepareAndFinalize.quant_stream,
|
||||
enabled=self.multistream_overlap_gate):
|
||||
hidden_states = fc3_all_gather_and_maybe_unpad_impl(
|
||||
hidden_states)
|
||||
PrepareAndFinalize.quant_stream.wait_stream(torch.npu.current_stream())
|
||||
with npu_stream_switch(PrepareAndFinalize.quant_stream, enabled=self.multistream_overlap_gate):
|
||||
hidden_states = fc3_all_gather_and_maybe_unpad_impl(hidden_states)
|
||||
else:
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
hidden_states, True, True)
|
||||
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
router_logits, True, True)
|
||||
hidden_states = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(hidden_states, True, True)
|
||||
router_logits = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(router_logits, True, True)
|
||||
|
||||
if pertoken_scale is not None:
|
||||
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||
pertoken_scale, True, True)
|
||||
pertoken_scale = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(pertoken_scale, True, True)
|
||||
|
||||
if self.multistream_overlap_gate:
|
||||
torch.npu.current_stream().wait_stream(
|
||||
PrepareAndFinalize.quant_stream)
|
||||
torch.npu.current_stream().wait_stream(PrepareAndFinalize.quant_stream)
|
||||
|
||||
if pertoken_scale is not None:
|
||||
return (hidden_states, pertoken_scale), router_logits, None, None
|
||||
@@ -393,9 +356,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
router_logits: torch.Tensor,
|
||||
enable_shared_expert_dp: bool = False,
|
||||
replace_allreduce: bool = False,
|
||||
quant_type=QuantType.NONE
|
||||
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor],
|
||||
Optional[torch.Tensor]]:
|
||||
quant_type=QuantType.NONE,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Preparation steps:
|
||||
1. Fetch max token count across DP group from forward context.
|
||||
@@ -413,16 +375,12 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
self.num_tokens = hidden_states.shape[0]
|
||||
pad_size = max_tokens_across_dp - self.num_tokens
|
||||
if pad_size > 0:
|
||||
hidden_states = nn.functional.pad(hidden_states,
|
||||
(0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits,
|
||||
(0, 0, 0, pad_size))
|
||||
hidden_states = nn.functional.pad(hidden_states, (0, 0, 0, pad_size))
|
||||
router_logits = nn.functional.pad(router_logits, (0, 0, 0, pad_size))
|
||||
|
||||
# All-gather across DP group
|
||||
hidden_states = self.moe_config.dp_group.all_gather(
|
||||
hidden_states, 0)
|
||||
router_logits = self.moe_config.dp_group.all_gather(
|
||||
router_logits, 0)
|
||||
hidden_states = self.moe_config.dp_group.all_gather(hidden_states, 0)
|
||||
router_logits = self.moe_config.dp_group.all_gather(router_logits, 0)
|
||||
|
||||
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
|
||||
hidden_states = get_pcp_group().all_gather(
|
||||
@@ -436,10 +394,9 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
|
||||
return hidden_states, router_logits, None, None
|
||||
|
||||
def finalize(self,
|
||||
hidden_states: torch.Tensor,
|
||||
reduce_results: bool,
|
||||
context_metadata: Optional[dict] = None) -> torch.Tensor:
|
||||
def finalize(
|
||||
self, hidden_states: torch.Tensor, reduce_results: bool, context_metadata: dict | None = None
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
Reduce Scatter hidden states.
|
||||
@@ -452,8 +409,7 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
|
||||
return self._finalize_with_dp_group(hidden_states, reduce_results)
|
||||
|
||||
def _finalize_with_ep_group(self,
|
||||
hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
def _finalize_with_ep_group(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Argument `reduce_results` is not needed in this func. Given sequence parallelism is enabled:
|
||||
1. Reduce_results is False usually happens when models have shared experts and need to
|
||||
@@ -463,13 +419,11 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
2 Reduce_results is True usually happens when model has no shared experts. We still do reduce scatter
|
||||
here, then skip allreudce in FusedMoe.
|
||||
"""
|
||||
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(
|
||||
hidden_states, True)
|
||||
hidden_states = torch.ops.vllm.maybe_pad_and_reduce(hidden_states, True)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def _finalize_with_dp_group(self, hidden_states: torch.Tensor,
|
||||
reduce_results: bool) -> torch.Tensor:
|
||||
def _finalize_with_dp_group(self, hidden_states: torch.Tensor, reduce_results: bool) -> torch.Tensor:
|
||||
"""
|
||||
Finalization steps:
|
||||
1. If DP > 1 and not shared expert, reduce-scatter output across DP group.
|
||||
@@ -481,9 +435,8 @@ class PrepareAndFinalizeWithAllGather(PrepareAndFinalize):
|
||||
"""
|
||||
if self.moe_config.dp_size > 1 and not self.enable_shared_expert_dp:
|
||||
hidden_states = get_dp_group().reduce_scatter(hidden_states, 0)
|
||||
hidden_states = hidden_states[:self.num_tokens]
|
||||
hidden_states = hidden_states[: self.num_tokens]
|
||||
|
||||
if prefill_context_parallel_enable() and self.moe_config.pcp_size > 1:
|
||||
hidden_states = get_pcp_group().reduce_scatter(hidden_states,
|
||||
dim=0)
|
||||
hidden_states = get_pcp_group().reduce_scatter(hidden_states, dim=0)
|
||||
return hidden_states
|
||||
|
||||
@@ -22,7 +22,6 @@
|
||||
# limitations under the License.
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
@@ -30,10 +29,8 @@ from vllm.config import get_current_vllm_config
|
||||
from vllm.distributed.parallel_state import get_ep_group
|
||||
|
||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||
from vllm_ascend.ops.fused_moe.comm_utils import (
|
||||
async_all_to_all, gather_from_sequence_parallel_region)
|
||||
from vllm_ascend.utils import (AscendDeviceType, get_ascend_device_type,
|
||||
is_hierarchical_communication_enabled)
|
||||
from vllm_ascend.ops.fused_moe.comm_utils import async_all_to_all, gather_from_sequence_parallel_region
|
||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, is_hierarchical_communication_enabled
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -52,7 +49,6 @@ class TokenCombineResult:
|
||||
|
||||
|
||||
class MoETokenDispatcher(ABC):
|
||||
|
||||
def __init__(self, **kwargs) -> None:
|
||||
"""
|
||||
Initialize the MoE Token Dispatcher.
|
||||
@@ -79,26 +75,24 @@ class MoETokenDispatcher(ABC):
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
pertoken_scale: Optional[torch.Tensor] = None,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
) -> TokenDispatchResult:
|
||||
raise NotImplementedError("Dispatch function not implemented.")
|
||||
|
||||
@abstractmethod
|
||||
def token_combine(self,
|
||||
hidden_states: torch.Tensor,
|
||||
context_metadata: dict,
|
||||
bias: torch.Tensor | None = None) -> TokenCombineResult:
|
||||
def token_combine(
|
||||
self, hidden_states: torch.Tensor, context_metadata: dict, bias: torch.Tensor | None = None
|
||||
) -> TokenCombineResult:
|
||||
raise NotImplementedError("Combine function not implemented.")
|
||||
|
||||
|
||||
class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
device_group = get_mc2_group().device_group
|
||||
@@ -108,10 +102,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
|
||||
self.ep_rank_id = get_mc2_group().rank_in_group
|
||||
self.ep_world_size = get_mc2_group().world_size
|
||||
self.enable_dispatch_v2 = hasattr(torch_npu,
|
||||
"npu_moe_distribute_dispatch_v2")
|
||||
self.need_extra_args = (
|
||||
get_ascend_device_type() == AscendDeviceType.A3)
|
||||
self.enable_dispatch_v2 = hasattr(torch_npu, "npu_moe_distribute_dispatch_v2")
|
||||
self.need_extra_args = get_ascend_device_type() == AscendDeviceType.A3
|
||||
|
||||
# NOTE: When in A2, setting the environment variables HCCL_INTRA_PCIE_ENABLE=1 and
|
||||
# HCCL_INTRA_ROCE_ENABLE=0 can reduce cross-machine communication traffic and significantly
|
||||
@@ -126,10 +118,8 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
compilation_config = vllm_config.compilation_config
|
||||
speculative_config = vllm_config.speculative_config
|
||||
tp_size = vllm_config.parallel_config.tensor_parallel_size
|
||||
uniform_decode_query_len = 1 if not speculative_config else \
|
||||
1 + speculative_config.num_speculative_tokens
|
||||
decode_max_num_seqs = getattr(scheduler_config, 'decode_max_num_seqs',
|
||||
0)
|
||||
uniform_decode_query_len = 1 if not speculative_config else 1 + speculative_config.num_speculative_tokens
|
||||
decode_max_num_seqs = getattr(scheduler_config, "decode_max_num_seqs", 0)
|
||||
max_num_reqs = max(scheduler_config.max_num_seqs, decode_max_num_seqs)
|
||||
if compilation_config.cudagraph_capture_sizes:
|
||||
max_num_tokens = compilation_config.max_cudagraph_capture_size
|
||||
@@ -167,44 +157,56 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"ep_rank_id": self.ep_rank_id,
|
||||
}
|
||||
if self.need_extra_args:
|
||||
stage1_kwargs.update({
|
||||
"group_tp": self.moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
stage1_kwargs.update(
|
||||
{
|
||||
"group_tp": self.moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
}
|
||||
)
|
||||
if self.need_expert_scale:
|
||||
stage1_kwargs.update({
|
||||
"expert_scales":
|
||||
topk_weights.to(torch.float32),
|
||||
})
|
||||
stage1_kwargs.update(
|
||||
{
|
||||
"expert_scales": topk_weights.to(torch.float32),
|
||||
}
|
||||
)
|
||||
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
return kwargs_mc2
|
||||
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
pertoken_scale: Optional[torch.Tensor] = None):
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
):
|
||||
self.with_quant = with_quant
|
||||
|
||||
kwargs_mc2 = self.get_dispatch_mc2_kwargs(hidden_states, topk_weights,
|
||||
topk_ids, expert_map,
|
||||
mc2_mask,
|
||||
global_redundant_expert_num)
|
||||
output = torch_npu.npu_moe_distribute_dispatch_v2(
|
||||
**kwargs_mc2
|
||||
) if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_dispatch(
|
||||
**kwargs_mc2)
|
||||
kwargs_mc2 = self.get_dispatch_mc2_kwargs(
|
||||
hidden_states, topk_weights, topk_ids, expert_map, mc2_mask, global_redundant_expert_num
|
||||
)
|
||||
output = (
|
||||
torch_npu.npu_moe_distribute_dispatch_v2(**kwargs_mc2)
|
||||
if self.enable_dispatch_v2
|
||||
else torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
||||
)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, assist_info_for_combine, expert_token_nums, \
|
||||
ep_recv_counts, tp_recv_counts, expand_scales = output[0:7]
|
||||
(
|
||||
expand_x,
|
||||
dynamic_scale,
|
||||
assist_info_for_combine,
|
||||
expert_token_nums,
|
||||
ep_recv_counts,
|
||||
tp_recv_counts,
|
||||
expand_scales,
|
||||
) = output[0:7]
|
||||
|
||||
context_metadata = {
|
||||
"topk_ids": topk_ids,
|
||||
@@ -213,18 +215,19 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
"ep_recv_counts": ep_recv_counts,
|
||||
"tp_recv_counts": tp_recv_counts,
|
||||
"assist_info_for_combine": assist_info_for_combine,
|
||||
"expand_scales": expand_scales
|
||||
"expand_scales": expand_scales,
|
||||
}
|
||||
|
||||
group_list_type = 0
|
||||
return TokenDispatchResult(hidden_states=expand_x,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list=expert_token_nums,
|
||||
group_list_type=group_list_type,
|
||||
context_metadata=context_metadata)
|
||||
return TokenDispatchResult(
|
||||
hidden_states=expand_x,
|
||||
dynamic_scale=dynamic_scale,
|
||||
group_list=expert_token_nums,
|
||||
group_list_type=group_list_type,
|
||||
context_metadata=context_metadata,
|
||||
)
|
||||
|
||||
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor,
|
||||
context_metadata: dict):
|
||||
def get_combine_mc_kwargs(self, hidden_states: torch.Tensor, context_metadata: dict):
|
||||
expert_map = context_metadata["expert_map"]
|
||||
topk_ids = context_metadata["topk_ids"]
|
||||
topk_weights = context_metadata["topk_weights"]
|
||||
@@ -246,9 +249,7 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
}
|
||||
|
||||
if self.with_quant:
|
||||
tp_recv_counts = torch.empty(1,
|
||||
dtype=torch.int32,
|
||||
device=hidden_states.device)
|
||||
tp_recv_counts = torch.empty(1, dtype=torch.int32, device=hidden_states.device)
|
||||
|
||||
stage3_kwargs = {
|
||||
"ep_send_counts": ep_recv_counts,
|
||||
@@ -264,12 +265,14 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
stage3_kwargs["expand_idx"] = assist_info_for_combine
|
||||
|
||||
if self.need_extra_args:
|
||||
stage3_kwargs.update({
|
||||
"tp_send_counts": tp_recv_counts,
|
||||
"group_tp": self.moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
})
|
||||
stage3_kwargs.update(
|
||||
{
|
||||
"tp_send_counts": tp_recv_counts,
|
||||
"group_tp": self.moe_all_to_all_group_name,
|
||||
"tp_world_size": 1,
|
||||
"tp_rank_id": 0,
|
||||
}
|
||||
)
|
||||
|
||||
kwargs_mc2.update(stage3_kwargs)
|
||||
return kwargs_mc2
|
||||
@@ -277,57 +280,58 @@ class TokenDispatcherWithMC2(MoETokenDispatcher):
|
||||
def token_combine(self, hidden_states, context_metadata, bias=None):
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
|
||||
|
||||
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states,
|
||||
context_metadata)
|
||||
combined_output = torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2) \
|
||||
if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
||||
kwargs_mc2 = self.get_combine_mc_kwargs(hidden_states, context_metadata)
|
||||
combined_output = (
|
||||
torch_npu.npu_moe_distribute_combine_v2(**kwargs_mc2)
|
||||
if self.enable_dispatch_v2
|
||||
else torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
||||
)
|
||||
|
||||
return TokenCombineResult(routed_out=combined_output, )
|
||||
return TokenCombineResult(
|
||||
routed_out=combined_output,
|
||||
)
|
||||
|
||||
|
||||
class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.apply_router_weight_on_input = False
|
||||
self.max_num_tokens = kwargs.get("max_num_tokens")
|
||||
num_experts_local = kwargs.get("num_local_experts", 0)
|
||||
self.num_experts_local = num_experts_local.item() if torch.is_tensor(
|
||||
num_experts_local) else int(num_experts_local)
|
||||
self.num_experts_local = (
|
||||
num_experts_local.item() if torch.is_tensor(num_experts_local) else int(num_experts_local)
|
||||
)
|
||||
self.original_shape = None
|
||||
self.with_quant = False
|
||||
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
pertoken_scale: Optional[torch.Tensor] = None):
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
):
|
||||
self.with_quant = with_quant
|
||||
self.original_shape = hidden_states.shape
|
||||
|
||||
num_tokens = hidden_states.shape[:-1].numel()
|
||||
self.apply_router_weight_on_input = apply_router_weight_on_input
|
||||
if self.apply_router_weight_on_input:
|
||||
assert (topk_weights.dim() == 2
|
||||
), "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
assert topk_weights.dim() == 2, "`topk_weights` should be in shape (num_tokens, topk)"
|
||||
_, topk = topk_weights.shape
|
||||
assert (
|
||||
topk == 1
|
||||
), "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * \
|
||||
topk_weights.to(hidden_states.dtype)
|
||||
assert topk == 1, "Only support topk=1 when `apply_router_weight_on_input` is True"
|
||||
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype)
|
||||
if expert_map is not None:
|
||||
global_num_experts = len(expert_map) + global_redundant_expert_num
|
||||
mask = (expert_map[topk_ids] != -1)
|
||||
mask = expert_map[topk_ids] != -1
|
||||
topk_weights = topk_weights * mask
|
||||
first_expert_idx = get_ep_group(
|
||||
).rank_in_group * self.num_experts_local
|
||||
first_expert_idx = get_ep_group().rank_in_group * self.num_experts_local
|
||||
last_expert_idx = first_expert_idx + self.num_experts_local
|
||||
else:
|
||||
first_expert_idx = 0
|
||||
@@ -344,15 +348,12 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
expert_tokens_num_type=1,
|
||||
expert_tokens_num_flag=True,
|
||||
active_expert_range=[first_expert_idx, last_expert_idx],
|
||||
quant_mode=1
|
||||
if self.with_quant and pertoken_scale is None else -1,
|
||||
))
|
||||
quant_mode=1 if self.with_quant and pertoken_scale is None else -1,
|
||||
)
|
||||
)
|
||||
expert_tokens = expert_tokens.to(torch.int64)
|
||||
group_list_type = 1 # `count` mode
|
||||
context_metadata = {
|
||||
"topk_weights": topk_weights,
|
||||
"expanded_row_idx": expanded_row_idx
|
||||
}
|
||||
context_metadata = {"topk_weights": topk_weights, "expanded_row_idx": expanded_row_idx}
|
||||
|
||||
return TokenDispatchResult(
|
||||
hidden_states=sorted_hidden_states,
|
||||
@@ -367,7 +368,8 @@ class TokenDispatcherWithAllGather(MoETokenDispatcher):
|
||||
final_hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=hidden_states,
|
||||
sorted_indices=torch.abs(context_metadata["expanded_row_idx"]),
|
||||
probs=context_metadata["topk_weights"])
|
||||
probs=context_metadata["topk_weights"],
|
||||
)
|
||||
if len(self.original_shape) == 3:
|
||||
final_hidden_states = final_hidden_states.view(self.original_shape)
|
||||
|
||||
@@ -398,35 +400,33 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
device=torch.npu.current_device(),
|
||||
)
|
||||
|
||||
local_expert_indices_offset = (self.ep_rank * self.num_local_experts)
|
||||
local_expert_indices_offset = self.ep_rank * self.num_local_experts
|
||||
|
||||
self.local_expert_indices = [
|
||||
local_expert_indices_offset + i
|
||||
for i in range(self.num_local_experts)
|
||||
]
|
||||
assert (len(self.local_expert_indices) == self.num_local_experts
|
||||
), "Invalid local expert indices"
|
||||
self.local_expert_indices = [local_expert_indices_offset + i for i in range(self.num_local_experts)]
|
||||
assert len(self.local_expert_indices) == self.num_local_experts, "Invalid local expert indices"
|
||||
for i in range(len(self.local_expert_indices) - 1):
|
||||
assert (self.local_expert_indices[i] ==
|
||||
self.local_expert_indices[i + 1] -
|
||||
1), "local_expert_indices must be continuous"
|
||||
assert self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1, (
|
||||
"local_expert_indices must be continuous"
|
||||
)
|
||||
|
||||
# TODO: Try local_rank = ep_group.rank_in_group
|
||||
local_rank = torch.distributed.get_rank(group=self.ep_group)
|
||||
backend = self.ep_group._get_backend(torch.device("npu"))
|
||||
self.moe_all_to_all_group_name = backend.get_hccl_comm_name(local_rank)
|
||||
|
||||
def token_dispatch(self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: Optional[torch.Tensor] = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: Optional[torch.Tensor] = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
pertoken_scale: Optional[torch.Tensor] = None):
|
||||
def token_dispatch(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
expert_map: torch.Tensor | None = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
mc2_mask: torch.Tensor | None = None,
|
||||
apply_router_weight_on_input: bool = False,
|
||||
with_quant: bool = False,
|
||||
dynamic_eplb: bool = False,
|
||||
pertoken_scale: torch.Tensor | None = None,
|
||||
):
|
||||
self.with_quant = with_quant
|
||||
self.hidden_shape = hidden_states.shape
|
||||
|
||||
@@ -442,35 +442,32 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
|
||||
dynamic_scale_after_all2all = None
|
||||
if self.with_quant:
|
||||
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(
|
||||
permutated_local_input_tokens)
|
||||
permutated_local_input_tokens, dynamic_scale = torch_npu.npu_dynamic_quant(permutated_local_input_tokens)
|
||||
_, dynamic_scale_after_all2all, permute2_ep_all_to_all_handle = async_all_to_all(
|
||||
dynamic_scale, output_splits, input_splits, self.ep_group)
|
||||
dynamic_scale, output_splits, input_splits, self.ep_group
|
||||
)
|
||||
permute2_ep_all_to_all_handle.wait()
|
||||
dynamic_scale.untyped_storage().resize_(0)
|
||||
|
||||
_, global_input_tokens, permute1_ep_all_to_all_handle = async_all_to_all(
|
||||
permutated_local_input_tokens, output_splits, input_splits,
|
||||
self.ep_group)
|
||||
permutated_local_input_tokens, output_splits, input_splits, self.ep_group
|
||||
)
|
||||
permute1_ep_all_to_all_handle.wait()
|
||||
permutated_local_input_tokens.untyped_storage().resize_(0)
|
||||
|
||||
# Postprocess
|
||||
global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = self._dispatch_postprocess(
|
||||
global_input_tokens, dynamic_scale_after_all2all,
|
||||
global_input_tokens_local_experts_indices)
|
||||
global_input_tokens, dynamic_scale_final, reversed_global_input_permutation_mapping = (
|
||||
self._dispatch_postprocess(
|
||||
global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices
|
||||
)
|
||||
)
|
||||
|
||||
context_metadata = {
|
||||
"input_splits":
|
||||
input_splits,
|
||||
"output_splits":
|
||||
output_splits,
|
||||
"topk_weights":
|
||||
topk_weights,
|
||||
"reversed_local_input_permutation_mapping":
|
||||
reversed_local_input_permutation_mapping,
|
||||
"reversed_global_input_permutation_mapping":
|
||||
reversed_global_input_permutation_mapping
|
||||
"input_splits": input_splits,
|
||||
"output_splits": output_splits,
|
||||
"topk_weights": topk_weights,
|
||||
"reversed_local_input_permutation_mapping": reversed_local_input_permutation_mapping,
|
||||
"reversed_global_input_permutation_mapping": reversed_global_input_permutation_mapping,
|
||||
}
|
||||
|
||||
return TokenDispatchResult(
|
||||
@@ -485,8 +482,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
assert bias is None, "Bias is not supported in MoEAlltoAllvTokenDispatcher."
|
||||
|
||||
# 1. Preprocess using metadata
|
||||
hidden_states = self._combine_preprocess(hidden_states,
|
||||
context_metadata)
|
||||
hidden_states = self._combine_preprocess(hidden_states, context_metadata)
|
||||
|
||||
# 2. AllToAll
|
||||
_, permutated_local_input_tokens, handle = async_all_to_all(
|
||||
@@ -499,8 +495,7 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
hidden_states.untyped_storage().resize_(0)
|
||||
|
||||
# 3. Postprocess using metadata
|
||||
output = self._combine_postprocess(permutated_local_input_tokens,
|
||||
context_metadata)
|
||||
output = self._combine_postprocess(permutated_local_input_tokens, context_metadata)
|
||||
|
||||
return TokenCombineResult(routed_out=output)
|
||||
|
||||
@@ -534,42 +529,39 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
)
|
||||
|
||||
def _preprocess(self, topk_ids: torch.Tensor):
|
||||
num_local_tokens_per_expert = torch.histc(topk_ids,
|
||||
bins=self.num_experts,
|
||||
min=0,
|
||||
max=self.num_experts)
|
||||
num_local_tokens_per_expert = torch.histc(topk_ids, bins=self.num_experts, min=0, max=self.num_experts)
|
||||
|
||||
ep_size = self.ep_size
|
||||
self.num_out_tokens = topk_ids.numel()
|
||||
|
||||
input_splits = (num_local_tokens_per_expert.reshape(
|
||||
ep_size,
|
||||
self.num_local_experts).sum(axis=1).to(torch.device("cpu"),
|
||||
non_blocking=True).numpy())
|
||||
input_splits = (
|
||||
num_local_tokens_per_expert.reshape(ep_size, self.num_local_experts)
|
||||
.sum(axis=1)
|
||||
.to(torch.device("cpu"), non_blocking=True)
|
||||
.numpy()
|
||||
)
|
||||
|
||||
num_global_tokens_per_expert = gather_from_sequence_parallel_region(
|
||||
num_local_tokens_per_expert,
|
||||
group=self.ep_group).reshape(ep_size, self.num_experts)
|
||||
num_global_tokens_per_local_expert = num_global_tokens_per_expert[:, self.local_expert_indices[
|
||||
0]:self.local_expert_indices[-1] + 1]
|
||||
num_local_tokens_per_expert, group=self.ep_group
|
||||
).reshape(ep_size, self.num_experts)
|
||||
num_global_tokens_per_local_expert = num_global_tokens_per_expert[
|
||||
:, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1
|
||||
]
|
||||
if num_global_tokens_per_local_expert is None:
|
||||
raise ValueError(
|
||||
"num_global_tokens_per_local_expert must be set before sum.")
|
||||
raise ValueError("num_global_tokens_per_local_expert must be set before sum.")
|
||||
|
||||
output_splits = (num_global_tokens_per_local_expert.sum(axis=-1).to(
|
||||
torch.device("cpu"), non_blocking=True).numpy())
|
||||
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(
|
||||
axis=0)
|
||||
output_splits = (
|
||||
num_global_tokens_per_local_expert.sum(axis=-1).to(torch.device("cpu"), non_blocking=True).numpy()
|
||||
)
|
||||
num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(axis=0)
|
||||
|
||||
global_input_tokens_local_experts_indices = None
|
||||
if self.num_local_experts > 1:
|
||||
if num_global_tokens_per_local_expert is None:
|
||||
raise ValueError(
|
||||
"num_global_tokens_per_local_expert must be set before operations."
|
||||
)
|
||||
raise ValueError("num_global_tokens_per_local_expert must be set before operations.")
|
||||
global_input_tokens_local_experts_indices = torch.repeat_interleave(
|
||||
self.expert_ids_per_ep_rank,
|
||||
num_global_tokens_per_local_expert.ravel())
|
||||
self.expert_ids_per_ep_rank, num_global_tokens_per_local_expert.ravel()
|
||||
)
|
||||
else:
|
||||
torch.npu.synchronize()
|
||||
|
||||
@@ -581,45 +573,41 @@ class TokenDispatcherWithAll2AllV(MoETokenDispatcher):
|
||||
global_input_tokens_local_experts_indices,
|
||||
)
|
||||
|
||||
def _dispatch_postprocess(self, global_input_tokens,
|
||||
dynamic_scale_after_all2all,
|
||||
global_input_tokens_local_experts_indices):
|
||||
def _dispatch_postprocess(
|
||||
self, global_input_tokens, dynamic_scale_after_all2all, global_input_tokens_local_experts_indices
|
||||
):
|
||||
# Early return if no local experts or no tokens
|
||||
if self.num_local_experts <= 1:
|
||||
return global_input_tokens, dynamic_scale_after_all2all, None
|
||||
|
||||
# Handle quantized case
|
||||
if self.with_quant:
|
||||
assert global_input_tokens_local_experts_indices is not None, \
|
||||
assert global_input_tokens_local_experts_indices is not None, (
|
||||
"global_input_tokens_local_experts_indices must be provided"
|
||||
)
|
||||
dynamic_scale_after_all2all, _ = torch_npu.npu_moe_token_permute(
|
||||
dynamic_scale_after_all2all.unsqueeze(-1),
|
||||
global_input_tokens_local_experts_indices)
|
||||
dynamic_scale_after_all2all = dynamic_scale_after_all2all.squeeze(
|
||||
-1)
|
||||
dynamic_scale_after_all2all.unsqueeze(-1), global_input_tokens_local_experts_indices
|
||||
)
|
||||
dynamic_scale_after_all2all = dynamic_scale_after_all2all.squeeze(-1)
|
||||
|
||||
# Non-quantized case
|
||||
global_input_tokens, reversed_global_input_permutation_mapping = torch_npu.npu_moe_token_permute(
|
||||
global_input_tokens, global_input_tokens_local_experts_indices)
|
||||
global_input_tokens, global_input_tokens_local_experts_indices
|
||||
)
|
||||
return global_input_tokens, dynamic_scale_after_all2all, reversed_global_input_permutation_mapping
|
||||
|
||||
def _combine_preprocess(self, hidden_states: torch.Tensor,
|
||||
context_metadata: dict) -> torch.Tensor:
|
||||
def _combine_preprocess(self, hidden_states: torch.Tensor, context_metadata: dict) -> torch.Tensor:
|
||||
# Unpermutation 2: expert output to AlltoAll input
|
||||
if hidden_states.shape[0] > 0 and self.num_local_experts > 1:
|
||||
rev_global = context_metadata[
|
||||
"reversed_global_input_permutation_mapping"]
|
||||
hidden_states = torch_npu.npu_moe_token_unpermute(
|
||||
hidden_states, rev_global)
|
||||
rev_global = context_metadata["reversed_global_input_permutation_mapping"]
|
||||
hidden_states = torch_npu.npu_moe_token_unpermute(hidden_states, rev_global)
|
||||
return hidden_states
|
||||
|
||||
def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor,
|
||||
context_metadata: dict) -> torch.Tensor:
|
||||
def _combine_postprocess(self, permutated_local_input_tokens: torch.Tensor, context_metadata: dict) -> torch.Tensor:
|
||||
# Unpermutation 1: AlltoAll output to output
|
||||
output = torch_npu.npu_moe_token_unpermute(
|
||||
permuted_tokens=permutated_local_input_tokens,
|
||||
sorted_indices=context_metadata[
|
||||
"reversed_local_input_permutation_mapping"].to(torch.int32),
|
||||
sorted_indices=context_metadata["reversed_local_input_permutation_mapping"].to(torch.int32),
|
||||
probs=context_metadata["topk_weights"],
|
||||
restore_shape=self.hidden_shape_before_permute,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user