[3/N][Feat][Graph] Support all-to-all and quantized models with ACL Graph (#2614)
### What this PR does / why we need it?
* **Unify execution paths:** Consolidates the quantized and
non-quantized execution paths into a single `fused_experts` function,
removing duplicated logic and making the control flow clearer and easier
to maintain.
* **W8A8 dynamic quantization:** Adds support for W8A8 dynamic
quantization inside the unified MoE kernel. Communication routines are
updated to correctly handle dynamic quantization scales for activations
(a stand-in sketch of the per-token scheme follows this list).
* **Weight pre-processing:** Pre-transposes the `w13` and `w2` weight
matrices (as implemented in PR #2025) so that quantized and
non-quantized models follow the same code path for the MoE gating,
up-projection, and down-projection operations (see the layout sketch below).
* **All-to-all communication:** Adds an `all-to-all` collective
communication pattern. For large token counts on modern hardware,
`all-to-all` is more efficient than the previous `all-gather` strategy.
However, `all-to-all` cannot be captured and replayed: it performs
multiple D2H operations that trigger synchronization and therefore raise
errors during graph capture. We only use `all-to-all` when falling back
to `compiled_graph_for_general_shape`.
* **Dynamic communication selection:** The model runner now selects the
optimal MoE communication method (`mc2`, `allgather`, or `alltoall`) at
runtime based on token count and the Ascend SoC version (see the
dispatch sketch after this list).
* **Limitation:** `all-gather` is not yet supported for quantized
models, so follow-up work is still needed on A2.
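
For reference, the per-token A8 activation scheme reduces to computing one
dynamic scale per token on the fly. Below is a minimal plain-torch stand-in
(the actual kernel uses Ascend-specific quantization ops;
`dynamic_quant_per_token` is a hypothetical name for illustration):

```python
import torch


def dynamic_quant_per_token(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Per-token symmetric int8 quantization: one dynamic scale per token,
    computed on the fly from the activation values themselves.
    (Illustrative stand-in, not the Ascend kernel.)"""
    # Map each row's max |value| onto the int8 range; clamp avoids div-by-zero.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
    # The per-token scales must travel with the tokens through the MoE
    # communication op so the grouped matmul can dequantize correctly.
    return x_q, scale.squeeze(-1)
```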
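The weight pre-processing is a one-time layout change at load. A minimal
sketch, assuming illustrative shapes and the hypothetical helper name
`pretranspose_moe_weights` (the actual processing follows PR #2025):

```python
import torch


def pretranspose_moe_weights(
        w13: torch.Tensor, w2: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Illustrative shapes: w13 is (num_experts, 2 * intermediate, hidden),
    w2 is (num_experts, hidden, intermediate)."""
    # transpose(1, 2) returns a view; .contiguous() materializes the new
    # layout once, so the forward pass needs no per-step transpose.
    return w13.transpose(1, 2).contiguous(), w2.transpose(1, 2).contiguous()
```

Doing this once at weight load lets the grouped matmul consume the same
layout for quantized and non-quantized experts.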
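The runtime selection reduces to a small dispatch. The sketch below is
illustrative only: the threshold, flag names, and SoC check are hypothetical
stand-ins for what the model runner actually reads from the Ascend platform
config:

```python
def select_moe_comm_method(num_tokens: int,
                           is_a3_soc: bool,
                           capturing_graph: bool,
                           mc2_token_limit: int = 512) -> str:
    """Pick among the three MoE communication patterns (illustrative)."""
    if num_tokens <= mc2_token_limit:
        return "mc2"        # low-latency fused dispatch/combine for small batches
    if capturing_graph or not is_a3_soc:
        return "allgather"  # capture-safe: no D2H sync during replay
    # all-to-all moves each token once instead of replicating every token to
    # every rank, but its D2H syncs forbid graph capture, so it only runs on
    # the general (uncaptured) path.
    return "alltoall"
```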
### Does this PR introduce _any_ user-facing change?
None.
### How was this patch tested?
No further test cases needed.
- vLLM version: v0.10.1.1
- vLLM main: d660c98c1b
---------
Signed-off-by: Yizhou Liu <liu_yizhou@outlook.com>
```diff
@@ -54,7 +54,8 @@ class MoECommMethod(ABC):
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        apply_a8_quantization: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         """Pre-process before MLP.
 
         Args:
@@ -64,6 +65,7 @@ class MoECommMethod(ABC):
             expert_map (torch.Tensor): Tensor of shape (global_num_experts, )
                 Mapping from global expert IDs to local expert IDs.
             num_experts (int): Number of local experts (experts on this device).
+            apply_a8_quantization (bool): Whether to apply A8 quantization (W4A8 and W8A8).
 
         Returns:
             tuple[torch.Tensor, torch.Tensor, int]: Return a tuple containing:
@@ -72,6 +74,8 @@ class MoECommMethod(ABC):
                 hidden_states based on topk_ids.
             - expert_tokens (torch.Tensor): Tensor of shape (num_experts, )
                 Number of tokens assigned to each expert.
+            - dynamic_scale (torch.Tensor, optional): Tensor of shape (num_experts, )
+                Dynamic scale for each expert, used for quantization.
             - group_list_type (int): Type of group list, 0 for `cumsum`
                 and 1 for `count`. This is mainly for `npu_grouped_matmul`
                 to determine how to handle the output.
@@ -159,7 +163,8 @@ class AllGatherCommImpl(MoECommMethod):
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,  # noqa: F841
         num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        apply_a8_quantization: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         num_tokens = hidden_states.shape[0]
 
         self.topk_weights = topk_weights
@@ -194,7 +199,7 @@ class AllGatherCommImpl(MoECommMethod):
 
         group_list_type = 1  # `count` mode
 
-        return permuted_hidden_states, expert_tokens, group_list_type
+        return permuted_hidden_states, expert_tokens, None, group_list_type
 
     def unpermute(self, mlp_output: torch.Tensor,
                   hidden_states: torch.Tensor) -> None:
@@ -219,7 +224,8 @@ class NativeAllGatherCommImpl(AllGatherCommImpl):
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        apply_a8_quantization: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         num_tokens = hidden_states.shape[0]
 
         # Generate token indices and flatten
@@ -269,7 +275,7 @@ class NativeAllGatherCommImpl(AllGatherCommImpl):
 
         group_list_type = 1  # `count` mode
 
-        return permuted_hidden_states, expert_tokens, group_list_type
+        return permuted_hidden_states, expert_tokens, None, group_list_type
 
     def unpermute(self, mlp_output: torch.Tensor,
                   hidden_states: torch.Tensor) -> None:
@@ -375,7 +381,8 @@ class MC2CommImpl(MoECommMethod):
         topk_weights: torch.Tensor,
         expert_map: torch.Tensor,
         num_experts: int,
-    ) -> tuple[torch.Tensor, torch.Tensor, int]:
+        apply_a8_quantization: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
         # Store tensors needed for post_process
         self.topk_ids = topk_ids
         self.topk_weights = topk_weights.to(torch.float32)
@@ -388,7 +395,7 @@ class MC2CommImpl(MoECommMethod):
             "moe_expert_num": self.moe_config.num_experts,
             "global_bs": 0,
             "scales": None,
-            "quant_mode": 0,
+            "quant_mode": 2 if apply_a8_quantization else 0,
             "group_ep": self.mc2_comm_name,
             "ep_world_size": self.moe_config.ep_size,
             "ep_rank_id": self.moe_config.ep_rank,
@@ -409,7 +416,7 @@ class MC2CommImpl(MoECommMethod):
 
         (
             permuted_hidden_states,
-            _,  # dynamic_scale is not used
+            dynamic_scale,
             self.assist_info_for_combine,
             expert_tokens,
             self.ep_recv_counts,
@@ -418,7 +425,7 @@ class MC2CommImpl(MoECommMethod):
 
         group_list_type = 1
 
-        return permuted_hidden_states, expert_tokens, group_list_type
+        return permuted_hidden_states, expert_tokens, dynamic_scale, group_list_type
 
     def unpermute(self, mlp_output: torch.Tensor,
                   hidden_states: torch.Tensor) -> None:
@@ -457,3 +464,93 @@ class MC2CommImpl(MoECommMethod):
         combine = torch_npu.npu_moe_distribute_combine_v2 if self.enable_dispatch_v2 else torch_npu.npu_moe_distribute_combine
 
         hidden_states[:] = combine(**combine_kwargs)
+
+
+class AlltoAllCommImpl(MoECommMethod):
+    """This implementation is for the scenarios listed below:
+    1. `enable_expert_parallel=True`.
+    2. `npu_grouped_matmul` is available.
+
+    This implementation uses all-to-all communication to exchange tokens
+    between data parallel ranks before and after the MLP computation. It should
+    have better performance than AllGatherCommImpl when DP size > 1.
+    """
+
+    def __init__(self, moe_config: Optional[FusedMoEConfig]):
+        super().__init__(moe_config)
+        from vllm_ascend.ops.moe_dispatcher.token_dispatcher import \
+            get_token_dispatcher
+        self.token_dispatcher = get_token_dispatcher(
+            "TokenDispatcherWithAll2AllV")
+        self._restore_tp_across_dp()
+
+    def _restore_tp_across_dp(self):
+        # NOTE: Since vLLM flatten tp across dp, we need to restore the original
+        # tp_size and tp_rank.
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
+
+    def prepare(
+            self, hidden_states: torch.Tensor,
+            router_logits: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+        self.num_tokens, _ = hidden_states.shape
+        pad_size = self.tp_size - self.num_tokens
+
+        if pad_size > 0:
+            hidden_states = nn.functional.pad(hidden_states,
+                                              (0, 0, 0, pad_size))
+            router_logits = nn.functional.pad(router_logits,
+                                              (0, 0, 0, pad_size))
+
+        if self.tp_size > 1:
+            split_hidden_states = torch.tensor_split(hidden_states,
+                                                     self.tp_size,
+                                                     dim=0)
+            split_router_logits = torch.tensor_split(router_logits,
+                                                     self.tp_size,
+                                                     dim=0)
+            self.split_hidden_states = split_hidden_states
+
+            hidden_states = split_hidden_states[self.tp_rank]
+            router_logits = split_router_logits[self.tp_rank]
+
+        return hidden_states, router_logits
+
+    def finalize(self, hidden_states: torch.Tensor,
+                 reduce_results: bool) -> torch.Tensor:
+        """If TP size > 1, all-gather the hidden states to get the final output.
+
+        Also, unpad the hidden states if needed.
+        """
+        if self.tp_size > 1:
+            dist.all_gather(list(self.split_hidden_states), hidden_states,
+                            self.moe_config.tp_group.device_group)
+            hidden_states = torch.cat(self.split_hidden_states, dim=0)
+
+        if self.num_tokens < hidden_states.shape[0]:
+            hidden_states = hidden_states[:self.num_tokens]
+
+        return hidden_states
+
+    def permute(
+        self,
+        hidden_states: torch.Tensor,
+        topk_ids: torch.Tensor,
+        topk_weights: torch.Tensor,
+        expert_map: torch.Tensor,
+        num_experts: int,
+        apply_a8_quantization: bool,
+    ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor], int]:
+        results = self.token_dispatcher.token_dispatch(
+            hidden_states,
+            topk_weights,
+            topk_ids,
+            None,
+            log2phy=None,
+            with_quant=apply_a8_quantization)
+        return results["hidden_states"], results["group_list"], results[
+            "dynamic_scale"], results["group_list_type"]
+
+    def unpermute(self, mlp_output: torch.Tensor,
+                  hidden_states: torch.Tensor) -> None:
+        hidden_states[:] = self.token_dispatcher.token_combine(mlp_output)
```