[3/N][Feat][Graph] Support all-to-all and quantized models with ACL Graph (#2614)
### What this PR does / why we need it?
* **Unify execution paths:** Consolidates the quantized and
non-quantized execution paths into a single `fused_experts` function,
removing duplicated logic and making the control flow clearer and easier
to maintain.
* **W8A8 dynamic quantization:** Adds support for W8A8 dynamic
quantization inside the unified MoE kernel. Communication routines are
updated to correctly handle dynamic quantization scales for activations.
* **Weight pre-processing:** Pre-transpose the `w13` and `w2` weight
matrices (as implemented in PR #2025) so that quantized and
non-quantized models follow the same code path for the MoE gating,
up-projection, and down-projection operations.
* **All-to-all communication:** Adds an `all-to-all` collective
communication pattern. For large token counts on modern hardware,
`all-to-all` is more efficient than the previous `all-gather` strategy.
However, `all-to-all` cannot be captured and replayed, because its
multiple D2H operations trigger synchronization and therefore raise
errors during graph capture. We only use `all-to-all` when falling back
to `compiled_graph_for_general_shape`.
* **Dynamic communication selection:** The model runner now selects the
optimal MoE communication method (`mc2`, `allgather`, or `alltoall`) at
runtime based on token count and the Ascend SoC version.
* **Limitation:** `all-gather` is not yet supported for quantized
models, so further work remains for A2 hardware.
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
No further test cases needed.
- vLLM version: v0.10.1.1
- vLLM main:
d660c98c1b
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
This commit is contained in:
@@ -33,6 +33,7 @@ from vllm_ascend.distributed.moe_comm_method import ( # isort: skip
|
|||||||
@pytest.mark.parametrize("top_k_num", [2, 4])
|
@pytest.mark.parametrize("top_k_num", [2, 4])
|
||||||
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
|
||||||
@pytest.mark.parametrize("ep_rank", [0, 1])
|
@pytest.mark.parametrize("ep_rank", [0, 1])
|
||||||
|
@pytest.mark.parametrize("apply_a8_quantization", [False])
|
||||||
def test_all_gather_comm_impl(
|
def test_all_gather_comm_impl(
|
||||||
num_tokens,
|
num_tokens,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
@@ -41,6 +42,7 @@ def test_all_gather_comm_impl(
|
|||||||
top_k_num,
|
top_k_num,
|
||||||
dtype,
|
dtype,
|
||||||
ep_rank,
|
ep_rank,
|
||||||
|
apply_a8_quantization,
|
||||||
mocker,
|
mocker,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
@@ -118,8 +120,9 @@ def test_all_gather_comm_impl(
|
|||||||
native_permuted_hidden,
|
native_permuted_hidden,
|
||||||
native_expert_tokens,
|
native_expert_tokens,
|
||||||
_,
|
_,
|
||||||
|
_,
|
||||||
) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
|
) = native_impl.permute(hidden_states, topk_ids, topk_weights, expert_map,
|
||||||
num_experts)
|
num_experts, apply_a8_quantization)
|
||||||
# Simulate MLP output
|
# Simulate MLP output
|
||||||
native_mlp_output = torch.randn_like(native_permuted_hidden)
|
native_mlp_output = torch.randn_like(native_permuted_hidden)
|
||||||
native_impl.unpermute(native_mlp_output, native_hidden_states_out)
|
native_impl.unpermute(native_mlp_output, native_hidden_states_out)
|
||||||
@@ -130,8 +133,9 @@ def test_all_gather_comm_impl(
|
|||||||
all_gather_permuted_hidden,
|
all_gather_permuted_hidden,
|
||||||
all_gather_expert_tokens,
|
all_gather_expert_tokens,
|
||||||
_,
|
_,
|
||||||
|
_,
|
||||||
) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
|
) = all_gather_impl.permute(hidden_states, topk_ids, topk_weights,
|
||||||
expert_map, num_experts)
|
expert_map, num_experts, apply_a8_quantization)
|
||||||
|
|
||||||
# Use the same simulated MLP output for a fair comparison
|
# Use the same simulated MLP output for a fair comparison
|
||||||
all_gather_mlp_output = native_mlp_output.clone()
|
all_gather_mlp_output = native_mlp_output.clone()
|
||||||
|
|||||||
@@ -107,4 +107,4 @@ def test_models_distributed_Qwen3_MOE_TP2_WITH_ACLGRAPH():
|
|||||||
tensor_parallel_size=2,
|
tensor_parallel_size=2,
|
||||||
enforce_eager=False,
|
enforce_eager=False,
|
||||||
) as vllm_model:
|
) as vllm_model:
|
||||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||||
|
|||||||
@@ -54,7 +54,8 @@ class MoECommMethod(ABC):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
expert_map: torch.Tensor,
|
expert_map: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, int]:
|
apply_a8_quantization: bool,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||||
"""Pre-process before MLP.
|
"""Pre-process before MLP.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@@ -64,6 +65,7 @@ class MoECommMethod(ABC):
|
|||||||
expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
|
expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
|
||||||
Mapping from global expert IDs to local expert IDs.
|
Mapping from global expert IDs to local expert IDs.
|
||||||
num_experts (int): Number of local experts (experts on this device).
|
num_experts (int): Number of local experts (experts on this device).
|
||||||
|
apply_a8_quantization (bool): Whether to apply A8 quantization (W4A8 and W8A8).
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
|
tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
|
||||||
@@ -72,6 +74,8 @@ class MoECommMethod(ABC):
|
|||||||
hidden_states based on topk_ids.
|
hidden_states based on topk_ids.
|
||||||
- expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
|
- expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
|
||||||
Number of tokens assigned to each expert.
|
Number of tokens assigned to each expert.
|
||||||
|
- dynamic_scale (torch.Tensor, optional): Tensor of shape (num_experts, )
|
||||||
|
Dynamic scale for each expert, used for quantization.
|
||||||
- group_list_type (int): Type of group list, 0 for `cumsum`
|
- group_list_type (int): Type of group list, 0 for `cumsum`
|
||||||
and 1 for `count`. This is mainly for `npu_grouped_matmul`
|
and 1 for `count`. This is mainly for `npu_grouped_matmul`
|
||||||
to determine how to handle the output.
|
to determine how to handle the output.
|
||||||
@@ -159,7 +163,8 @@ class AllGatherCommImpl(MoECommMethod):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
expert_map: torch.Tensor, # noqa: F841
|
expert_map: torch.Tensor, # noqa: F841
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, int]:
|
apply_a8_quantization: bool,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||||
num_tokens = hidden_states.shape[0]
|
num_tokens = hidden_states.shape[0]
|
||||||
|
|
||||||
self.topk_weights = topk_weights
|
self.topk_weights = topk_weights
|
||||||
@@ -194,7 +199,7 @@ class AllGatherCommImpl(MoECommMethod):
|
|||||||
|
|
||||||
group_list_type = 1 # `count` mode
|
group_list_type = 1 # `count` mode
|
||||||
|
|
||||||
return permuted_hidden_states, expert_tokens, group_list_type
|
return permuted_hidden_states, expert_tokens, None, group_list_type
|
||||||
|
|
||||||
def unpermute(self, mlp_output: torch.Tensor,
|
def unpermute(self, mlp_output: torch.Tensor,
|
||||||
hidden_states: torch.Tensor) -> None:
|
hidden_states: torch.Tensor) -> None:
|
||||||
@@ -219,7 +224,8 @@ class NativeAllGatherCommImpl(AllGatherCommImpl):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
expert_map: torch.Tensor,
|
expert_map: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, int]:
|
apply_a8_quantization: bool,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||||
num_tokens = hidden_states.shape[0]
|
num_tokens = hidden_states.shape[0]
|
||||||
|
|
||||||
# Generate token indices and flatten
|
# Generate token indices and flatten
|
||||||
@@ -269,7 +275,7 @@ class NativeAllGatherCommImpl(AllGatherCommImpl):
|
|||||||
|
|
||||||
group_list_type = 1 # `count` mode
|
group_list_type = 1 # `count` mode
|
||||||
|
|
||||||
return permuted_hidden_states, expert_tokens, group_list_type
|
return permuted_hidden_states, expert_tokens, None, group_list_type
|
||||||
|
|
||||||
def unpermute(self, mlp_output: torch.Tensor,
|
def unpermute(self, mlp_output: torch.Tensor,
|
||||||
hidden_states: torch.Tensor) -> None:
|
hidden_states: torch.Tensor) -> None:
|
||||||
@@ -375,7 +381,8 @@ class MC2CommImpl(MoECommMethod):
|
|||||||
topk_weights: torch.Tensor,
|
topk_weights: torch.Tensor,
|
||||||
expert_map: torch.Tensor,
|
expert_map: torch.Tensor,
|
||||||
num_experts: int,
|
num_experts: int,
|
||||||
) -> tuple[torch.Tensor, torch.Tensor, int]:
|
apply_a8_quantization: bool,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||||
# Store tensors needed for post_process
|
# Store tensors needed for post_process
|
||||||
self.topk_ids = topk_ids
|
self.topk_ids = topk_ids
|
||||||
self.topk_weights = topk_weights.to(torch.float32)
|
self.topk_weights = topk_weights.to(torch.float32)
|
||||||
@@ -388,7 +395,7 @@ class MC2CommImpl(MoECommMethod):
|
|||||||
"moe_expert_num": self.moe_config.num_experts,
|
"moe_expert_num": self.moe_config.num_experts,
|
||||||
"global_bs": 0,
|
"global_bs": 0,
|
||||||
"scales": None,
|
"scales": None,
|
||||||
"quant_mode": 0,
|
"quant_mode": 2 if apply_a8_quantization else 0,
|
||||||
"group_ep": self.mc2_comm_name,
|
"group_ep": self.mc2_comm_name,
|
||||||
"ep_world_size": self.moe_config.ep_size,
|
"ep_world_size": self.moe_config.ep_size,
|
||||||
"ep_rank_id": self.moe_config.ep_rank,
|
"ep_rank_id": self.moe_config.ep_rank,
|
||||||
@@ -409,7 +416,7 @@ class MC2CommImpl(MoECommMethod):
|
|||||||
|
|
||||||
(
|
(
|
||||||
permuted_hidden_states,
|
permuted_hidden_states,
|
||||||
_, # dynamic_scale is not used
|
dynamic_scale,
|
||||||
self.assist_info_for_combine,
|
self.assist_info_for_combine,
|
||||||
expert_tokens,
|
expert_tokens,
|
||||||
self.ep_recv_counts,
|
self.ep_recv_counts,
|
||||||
@@ -418,7 +425,7 @@ class MC2CommImpl(MoECommMethod):
|
|||||||
|
|
||||||
group_list_type = 1
|
group_list_type = 1
|
||||||
|
|
||||||
return permuted_hidden_states, expert_tokens, group_list_type
|
return permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type
|
||||||
|
|
||||||
def unpermute(self, mlp_output: torch.Tensor,
|
def unpermute(self, mlp_output: torch.Tensor,
|
||||||
hidden_states: torch.Tensor) -> None:
|
hidden_states: torch.Tensor) -> None:
|
||||||
@@ -457,3 +464,93 @@ class MC2CommImpl(MoECommMethod):
|
|||||||
combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine
|
combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine
|
||||||
|
|
||||||
hidden_states[:] = combine(**combine_kwargs)
|
hidden_states[:] = combine(**combine_kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
class AlltoAllCommImpl(MoECommMethod):
|
||||||
|
"""This implementation is for the scenarios listed below:
|
||||||
|
1. `enable_expert_parallel=True`.
|
||||||
|
2. `npu_grouped_matmul` is available.
|
||||||
|
|
||||||
|
This implementation uses all-to-all communication to exchange tokens
|
||||||
|
between data parallel ranks before and after the MLP computation. It should
|
||||||
|
have better performance than AllGatherCommImpl when DP size > 1.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, moe_config: Optional[FusedMoEConfig]):
|
||||||
|
super().__init__(moe_config)
|
||||||
|
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||||
|
get_token_dispatcher
|
||||||
|
self.token_dispatcher = get_token_dispatcher(
|
||||||
|
"TokenDispatcherWithAll2AllV")
|
||||||
|
self._restore_tp_across_dp()
|
||||||
|
|
||||||
|
def _restore_tp_across_dp(self):
|
||||||
|
# NOTE: Since vLLM flatten tp across dp, we need to restore the original
|
||||||
|
# tp_size and tp_rank.
|
||||||
|
self.tp_size = get_tensor_model_parallel_world_size()
|
||||||
|
self.tp_rank = get_tensor_model_parallel_rank()
|
||||||
|
|
||||||
|
def prepare(
|
||||||
|
self, hidden_states: torch.Tensor,
|
||||||
|
router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
self.num_tokens, _ = hidden_states.shape
|
||||||
|
pad_size = self.tp_size - self.num_tokens
|
||||||
|
|
||||||
|
if pad_size > 0:
|
||||||
|
hidden_states = nn.functional.pad(hidden_states,
|
||||||
|
(0, 0, 0, pad_size))
|
||||||
|
router_logits = nn.functional.pad(router_logits,
|
||||||
|
(0, 0, 0, pad_size))
|
||||||
|
|
||||||
|
if self.tp_size > 1:
|
||||||
|
split_hidden_states = torch.tensor_split(hidden_states,
|
||||||
|
self.tp_size,
|
||||||
|
dim=0)
|
||||||
|
split_router_logits = torch.tensor_split(router_logits,
|
||||||
|
self.tp_size,
|
||||||
|
dim=0)
|
||||||
|
self.split_hidden_states = split_hidden_states
|
||||||
|
|
||||||
|
hidden_states = split_hidden_states[self.tp_rank]
|
||||||
|
router_logits = split_router_logits[self.tp_rank]
|
||||||
|
|
||||||
|
return hidden_states, router_logits
|
||||||
|
|
||||||
|
def finalize(self, hidden_states: torch.Tensor,
|
||||||
|
reduce_results: bool) -> torch.Tensor:
|
||||||
|
"""If TP size > 1, all-gather the hidden states to get the final output.
|
||||||
|
|
||||||
|
Also, unpad the hidden states if needed.
|
||||||
|
"""
|
||||||
|
if self.tp_size > 1:
|
||||||
|
dist.all_gather(list(self.split_hidden_states), hidden_states,
|
||||||
|
self.moe_config.tp_group.device_group)
|
||||||
|
hidden_states = torch.cat(self.split_hidden_states, dim=0)
|
||||||
|
|
||||||
|
if self.num_tokens < hidden_states.shape[0]:
|
||||||
|
hidden_states = hidden_states[:self.num_tokens]
|
||||||
|
|
||||||
|
return hidden_states
|
||||||
|
|
||||||
|
def permute(
|
||||||
|
self,
|
||||||
|
hidden_states: torch.Tensor,
|
||||||
|
topk_ids: torch.Tensor,
|
||||||
|
topk_weights: torch.Tensor,
|
||||||
|
expert_map: torch.Tensor,
|
||||||
|
num_experts: int,
|
||||||
|
apply_a8_quantization: bool,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
|
||||||
|
results = self.token_dispatcher.token_dispatch(
|
||||||
|
hidden_states,
|
||||||
|
topk_weights,
|
||||||
|
topk_ids,
|
||||||
|
None,
|
||||||
|
log2phy=None,
|
||||||
|
with_quant=apply_a8_quantization)
|
||||||
|
return results["hidden_states"], results["group_list"], results[
|
||||||
|
"dynamic_scale"], results["group_list_type"]
|
||||||
|
|
||||||
|
def unpermute(self, mlp_output: torch.Tensor,
|
||||||
|
hidden_states: torch.Tensor) -> None:
|
||||||
|
hidden_states[:] = self.token_dispatcher.token_combine(mlp_output)
|
||||||
|
|||||||
@@ -18,6 +18,7 @@
|
|||||||
from typing import Any, Callable, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
import torch_npu
|
||||||
from vllm.config import CompilationLevel, get_current_vllm_config
|
from vllm.config import CompilationLevel, get_current_vllm_config
|
||||||
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
from vllm.distributed import get_dp_group, get_ep_group, get_tp_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
@@ -26,12 +27,14 @@ from vllm.model_executor.layers.fused_moe.layer import (
|
|||||||
|
|
||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
from vllm_ascend.distributed.moe_comm_method import (AllGatherCommImpl,
|
||||||
MC2CommImpl,
|
AlltoAllCommImpl,
|
||||||
MoECommMethod)
|
MC2CommImpl)
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
from vllm_ascend.ops.fused_moe import apply_mlp, fused_experts_moge
|
from vllm_ascend.ops.fused_moe import fused_experts_moge
|
||||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||||
from vllm_ascend.utils import is_310p
|
from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
|
||||||
|
setup_token_dispatchers
|
||||||
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, is_310p
|
||||||
|
|
||||||
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
original_unquantized_fused_moe_init_func = UnquantizedFusedMoEMethod.__init__
|
||||||
|
|
||||||
@@ -52,7 +55,6 @@ def fused_experts(
|
|||||||
w2_scale: Optional[torch.Tensor] = None,
|
w2_scale: Optional[torch.Tensor] = None,
|
||||||
w1_scale_bias: torch.Tensor = None,
|
w1_scale_bias: torch.Tensor = None,
|
||||||
w2_scale_bias: torch.Tensor = None,
|
w2_scale_bias: torch.Tensor = None,
|
||||||
moe_comm_method: Optional[MoECommMethod] = None,
|
|
||||||
# For TorchAir graph
|
# For TorchAir graph
|
||||||
is_torchair: bool = False,
|
is_torchair: bool = False,
|
||||||
# For Cube/Vector parallel
|
# For Cube/Vector parallel
|
||||||
@@ -64,9 +66,8 @@ def fused_experts(
|
|||||||
global_redundant_expert_num: int = 0,
|
global_redundant_expert_num: int = 0,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# Check constraints
|
# Check constraints
|
||||||
assert hidden_states.shape[1] == w1.shape[2], (
|
assert hidden_states.shape[1] == w1.shape[1], (
|
||||||
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[2]}")
|
f"Hidden size mismatch {hidden_states.shape[1]} != {w1.shape[1]}")
|
||||||
|
|
||||||
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
|
||||||
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
|
||||||
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
assert w1.stride(-1) == 1, "Stride of last dimension must be 1"
|
||||||
@@ -74,31 +75,79 @@ def fused_experts(
|
|||||||
assert hidden_states.dtype in [
|
assert hidden_states.dtype in [
|
||||||
torch.float32, torch.float16, torch.bfloat16
|
torch.float32, torch.float16, torch.bfloat16
|
||||||
]
|
]
|
||||||
|
if (use_int8_w8a8 or use_int4_w4a8):
|
||||||
|
assert w1_scale is not None and w2_scale is not None, \
|
||||||
|
"INT8 quantization requires weight scales."
|
||||||
|
|
||||||
|
w1_scale = w1_scale.to(torch.float32)
|
||||||
|
down_scale = [w2_scale]
|
||||||
|
down_output_dtype = w2_scale.dtype
|
||||||
|
else:
|
||||||
|
down_scale = None
|
||||||
|
down_output_dtype = None
|
||||||
|
|
||||||
|
moe_comm_method = get_forward_context().moe_comm_method
|
||||||
assert moe_comm_method is not None, "Missing communication context"
|
assert moe_comm_method is not None, "Missing communication context"
|
||||||
|
|
||||||
num_experts = w1.shape[0]
|
num_experts = w1.shape[0]
|
||||||
|
|
||||||
permuted_hidden_states, expert_tokens, group_list_type = moe_comm_method.permute(
|
permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type = moe_comm_method.permute(
|
||||||
hidden_states, topk_ids, topk_weights, expert_map, num_experts)
|
hidden_states, topk_ids, topk_weights, expert_map, num_experts,
|
||||||
mlp_output = apply_mlp(
|
use_int8_w8a8 or use_int4_w4a8)
|
||||||
permuted_hidden_states,
|
|
||||||
w1,
|
gate_up_output = torch_npu.npu_grouped_matmul(
|
||||||
w2,
|
x=[permuted_hidden_states],
|
||||||
expert_tokens,
|
weight=[w1],
|
||||||
|
split_item=2,
|
||||||
group_list_type=group_list_type,
|
group_list_type=group_list_type,
|
||||||
)
|
group_type=0,
|
||||||
moe_comm_method.unpermute(mlp_output, hidden_states)
|
group_list=expert_tokens,
|
||||||
|
output_dtype=torch.int32 if use_int8_w8a8 else None,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
if (use_int8_w8a8 or use_int4_w4a8):
|
||||||
|
activated_output, activated_output_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||||
|
x=gate_up_output,
|
||||||
|
weight_scale=w1_scale,
|
||||||
|
activation_scale=dynamic_scale,
|
||||||
|
bias=None,
|
||||||
|
quant_scale=None,
|
||||||
|
quant_offset=None,
|
||||||
|
group_index=expert_tokens,
|
||||||
|
activate_left=True,
|
||||||
|
quant_mode=1,
|
||||||
|
)
|
||||||
|
activated_output_scale = [activated_output_scale]
|
||||||
|
else:
|
||||||
|
activated_output = torch_npu.npu_swiglu(gate_up_output)
|
||||||
|
activated_output_scale = None
|
||||||
|
|
||||||
|
down_output = torch_npu.npu_grouped_matmul(
|
||||||
|
x=[activated_output],
|
||||||
|
weight=[w2],
|
||||||
|
scale=down_scale,
|
||||||
|
per_token_scale=activated_output_scale,
|
||||||
|
split_item=2,
|
||||||
|
group_list_type=group_list_type,
|
||||||
|
group_type=0,
|
||||||
|
group_list=expert_tokens,
|
||||||
|
output_dtype=down_output_dtype,
|
||||||
|
)[0]
|
||||||
|
|
||||||
|
moe_comm_method.unpermute(down_output, hidden_states)
|
||||||
|
|
||||||
return hidden_states
|
return hidden_states
|
||||||
|
|
||||||
|
|
||||||
def unquantized_fused_moe_init_func(self, *args, **kwargs):
|
def unquantized_fused_moe_init_func(self, *args, **kwargs):
|
||||||
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
|
original_unquantized_fused_moe_init_func(self, *args, **kwargs)
|
||||||
|
|
||||||
|
# NOTE: Currently, this self.use_aclgraph is only used in
|
||||||
|
# UnquantizedFusedMoEMethod.forward_oot to decide whether to use in
|
||||||
|
# ops/fused_moe.py:568 to circumvent torch.randint_like not supported issue.
|
||||||
|
# Once torch.randint_like is supported or removed, this flag can be removed.
|
||||||
vllm_config = get_current_vllm_config()
|
vllm_config = get_current_vllm_config()
|
||||||
self.max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
|
||||||
|
|
||||||
ascend_config = get_ascend_config()
|
ascend_config = get_ascend_config()
|
||||||
|
|
||||||
if ascend_config.torchair_graph_config.enabled:
|
if ascend_config.torchair_graph_config.enabled:
|
||||||
self.use_aclgraph = False
|
self.use_aclgraph = False
|
||||||
else:
|
else:
|
||||||
@@ -156,8 +205,6 @@ def forward_oot(
|
|||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
apply_router_weight_on_input=apply_router_weight_on_input)
|
apply_router_weight_on_input=apply_router_weight_on_input)
|
||||||
|
|
||||||
moe_comm_method = get_forward_context().moe_comm_method
|
|
||||||
|
|
||||||
return fused_experts(
|
return fused_experts(
|
||||||
hidden_states=x,
|
hidden_states=x,
|
||||||
w1=layer.w13_weight,
|
w1=layer.w13_weight,
|
||||||
@@ -166,10 +213,26 @@ def forward_oot(
|
|||||||
topk_ids=topk_ids,
|
topk_ids=topk_ids,
|
||||||
global_num_experts=global_num_experts,
|
global_num_experts=global_num_experts,
|
||||||
expert_map=expert_map,
|
expert_map=expert_map,
|
||||||
moe_comm_method=moe_comm_method,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def process_weights_after_loading(self, layer):
|
||||||
|
super(UnquantizedFusedMoEMethod, self).process_weights_after_loading(layer)
|
||||||
|
w13_data = self._maybe_pad_weight(layer.w13_weight.data).transpose(
|
||||||
|
1, 2).contiguous()
|
||||||
|
layer.w13_weight = torch.nn.Parameter(w13_data, requires_grad=False)
|
||||||
|
|
||||||
|
w2_data = self._maybe_pad_weight(layer.w2_weight.data).transpose(
|
||||||
|
1, 2).contiguous()
|
||||||
|
layer.w2_weight = torch.nn.Parameter(w2_data, requires_grad=False)
|
||||||
|
|
||||||
|
if not is_310p():
|
||||||
|
layer.w13_weight.data = torch_npu.npu_format_cast(
|
||||||
|
layer.w13_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||||
|
layer.w2_weight.data = torch_npu.npu_format_cast(
|
||||||
|
layer.w2_weight.data, ACL_FORMAT_FRACTAL_NZ)
|
||||||
|
|
||||||
|
|
||||||
class AscendFusedMoE(FusedMoE):
|
class AscendFusedMoE(FusedMoE):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -224,12 +287,17 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
has_bias,
|
has_bias,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
setup_token_dispatchers(self.moe_config.ep_size,
|
||||||
|
top_k=self.top_k,
|
||||||
|
num_experts=self.global_num_experts,
|
||||||
|
num_local_experts=self.local_num_experts)
|
||||||
|
|
||||||
self.moe_config.tp_group = get_tp_group()
|
self.moe_config.tp_group = get_tp_group()
|
||||||
self.moe_config.dp_group = get_dp_group()
|
self.moe_config.dp_group = get_dp_group()
|
||||||
self.moe_config.ep_group = get_ep_group()
|
self.moe_config.ep_group = get_ep_group()
|
||||||
self.moe_config.mc2_group = get_mc2_group()
|
self.moe_config.mc2_group = get_mc2_group()
|
||||||
|
|
||||||
for method in {AllGatherCommImpl, MC2CommImpl}:
|
for method in {AllGatherCommImpl, AlltoAllCommImpl, MC2CommImpl}:
|
||||||
setattr(
|
setattr(
|
||||||
self, method.__name__.lower(),
|
self, method.__name__.lower(),
|
||||||
method(moe_config=self.moe_config)) # type: ignore[abstract]
|
method(moe_config=self.moe_config)) # type: ignore[abstract]
|
||||||
@@ -282,4 +350,5 @@ class AscendFusedMoE(FusedMoE):
|
|||||||
|
|
||||||
|
|
||||||
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
|
UnquantizedFusedMoEMethod.__init__ = unquantized_fused_moe_init_func
|
||||||
|
UnquantizedFusedMoEMethod.process_weights_after_loading = process_weights_after_loading
|
||||||
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|
UnquantizedFusedMoEMethod.forward_oot = forward_oot
|
||||||
|
|||||||
@@ -230,7 +230,6 @@ def fused_experts_moge(
|
|||||||
0, sorted_topk_ids).unsqueeze(-1)
|
0, sorted_topk_ids).unsqueeze(-1)
|
||||||
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
group_list = num_tokens_per_expert.cumsum(dim=0).to(torch.int64)
|
||||||
|
|
||||||
w1 = w1.transpose(1, 2)
|
|
||||||
gate_up_out = torch_npu.npu_grouped_matmul(
|
gate_up_out = torch_npu.npu_grouped_matmul(
|
||||||
x=[sorted_hidden_states],
|
x=[sorted_hidden_states],
|
||||||
weight=[w1],
|
weight=[w1],
|
||||||
@@ -247,7 +246,6 @@ def fused_experts_moge(
|
|||||||
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
gate_up_out = torch_npu.npu_swiglu(gate_up_out)
|
||||||
gate_up_out *= topk_scales
|
gate_up_out *= topk_scales
|
||||||
|
|
||||||
w2 = w2.transpose(1, 2)
|
|
||||||
down_out_list = torch_npu.npu_grouped_matmul(
|
down_out_list = torch_npu.npu_grouped_matmul(
|
||||||
x=[gate_up_out],
|
x=[gate_up_out],
|
||||||
weight=[w2],
|
weight=[w2],
|
||||||
|
|||||||
@@ -19,12 +19,16 @@ from typing import Any, Callable, Dict, Optional, Tuple, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch_npu
|
import torch_npu
|
||||||
|
from vllm.config import CompilationLevel, get_current_vllm_config
|
||||||
from vllm.distributed import get_ep_group
|
from vllm.distributed import get_ep_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
|
|
||||||
import vllm_ascend.envs as envs_ascend
|
import vllm_ascend.envs as envs_ascend
|
||||||
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.ascend_forward_context import FusedMoEState
|
from vllm_ascend.ascend_forward_context import FusedMoEState
|
||||||
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
from vllm_ascend.distributed.parallel_state import get_mc2_group
|
||||||
|
from vllm_ascend.ops.common_fused_moe import \
|
||||||
|
fused_experts as unified_fused_experts
|
||||||
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
|
from vllm_ascend.ops.fused_moe import unified_fused_experts_eager
|
||||||
from vllm_ascend.ops.layers.experts_selector import select_experts
|
from vllm_ascend.ops.layers.experts_selector import select_experts
|
||||||
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, dispose_tensor
|
from vllm_ascend.utils import ACL_FORMAT_FRACTAL_NZ, dispose_tensor
|
||||||
@@ -283,6 +287,13 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
|
|
||||||
self.ep_group = get_ep_group()
|
self.ep_group = get_ep_group()
|
||||||
|
|
||||||
|
vllm_config = get_current_vllm_config()
|
||||||
|
ascend_config = get_ascend_config()
|
||||||
|
self.use_aclgraph = (
|
||||||
|
vllm_config.compilation_config.level == CompilationLevel.PIECEWISE
|
||||||
|
and not vllm_config.model_config.enforce_eager
|
||||||
|
and not ascend_config.torchair_graph_config.enabled)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
device_group = get_mc2_group().device_group
|
device_group = get_mc2_group().device_group
|
||||||
# TODO: Try local_rank = ep_group.rank_in_group
|
# TODO: Try local_rank = ep_group.rank_in_group
|
||||||
@@ -375,6 +386,19 @@ class AscendW8A8DynamicFusedMoEMethod:
|
|||||||
e_score_correction_bias=e_score_correction_bias,
|
e_score_correction_bias=e_score_correction_bias,
|
||||||
global_num_experts=global_num_experts)
|
global_num_experts=global_num_experts)
|
||||||
|
|
||||||
|
if self.use_aclgraph:
|
||||||
|
return unified_fused_experts(
|
||||||
|
hidden_states=x,
|
||||||
|
w1=layer.w13_weight,
|
||||||
|
w2=layer.w2_weight,
|
||||||
|
topk_weights=topk_weights,
|
||||||
|
topk_ids=topk_ids,
|
||||||
|
use_int8_w8a8=True,
|
||||||
|
w1_scale=layer.w13_weight_scale,
|
||||||
|
w2_scale=layer.w2_weight_scale,
|
||||||
|
expert_map=expert_map,
|
||||||
|
)
|
||||||
|
|
||||||
fused_moe_state = get_forward_context().fused_moe_state
|
fused_moe_state = get_forward_context().fused_moe_state
|
||||||
shared_gate_up, shared_dequant_scale = None, None
|
shared_gate_up, shared_dequant_scale = None, None
|
||||||
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
|
if shared_experts is not None and fused_moe_state == FusedMoEState.MC2:
|
||||||
|
|||||||
@@ -89,7 +89,8 @@ from vllm_ascend.sample.rejection_sampler import AscendRejectionSampler
|
|||||||
from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
|
from vllm_ascend.torchair.torchair_attention import AscendTorchairMetadata
|
||||||
from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata
|
from vllm_ascend.torchair.torchair_mla import AscendMLATorchairMetadata
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
||||||
ProfileExecuteDuration, is_310p,
|
AscendSocVersion, ProfileExecuteDuration,
|
||||||
|
get_ascend_soc_version, is_310p,
|
||||||
lmhead_tp_enable, vllm_version_is)
|
lmhead_tp_enable, vllm_version_is)
|
||||||
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
|
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
|
||||||
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
|
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
|
||||||
@@ -1620,8 +1621,22 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _select_moe_comm_method(self, num_tokens: int) -> str:
|
def _select_moe_comm_method(self, num_tokens: int) -> str:
|
||||||
return ("mc2"
|
soc_version = get_ascend_soc_version()
|
||||||
if num_tokens <= self.mc2_tokens_capacity else "allgather")
|
|
||||||
|
if num_tokens <= self.mc2_tokens_capacity:
|
||||||
|
moe_comm_method = "mc2"
|
||||||
|
elif soc_version in {AscendSocVersion.A2}:
|
||||||
|
moe_comm_method = "allgather"
|
||||||
|
elif soc_version in {AscendSocVersion.A3}:
|
||||||
|
moe_comm_method = "alltoall"
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unsupported soc_version: {soc_version}")
|
||||||
|
|
||||||
|
if is_global_first_rank():
|
||||||
|
logger.debug(f"num_tokens: {num_tokens}, "
|
||||||
|
f"moe_comm_method: {moe_comm_method}")
|
||||||
|
|
||||||
|
return moe_comm_method
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def execute_model(
|
def execute_model(
|
||||||
|
|||||||
Reference in New Issue
Block a user