Support multistream of shared experts in FusedMoE (#997)
Contains on #1111 for completeness. <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? Implement multi-stream parallelism for MoE layers with shared experts, where computation of shared experts will be overlapped with expert token dispatch and combine. Also, when multi-stream is enabled, weights of shared experts will be force to replicate across all cards, regardless of any tensor parallelism configurations, to avoid AllReduce operations. With the expected overlaping being: ``` | shared gate_up | shared act | | shared down | | dispatch | routed gate_up, act, down | combine | ``` <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? No. <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? Tested on 1x16 910 node, with tailored 2 layer DSKv2. <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> --------- Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
This commit is contained in:
@@ -15,19 +15,19 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
import torchair as tng # type: ignore
|
||||
from vllm.distributed import GroupCoordinator, tensor_model_parallel_all_reduce
|
||||
from vllm.distributed import GroupCoordinator
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.ops.fused_moe import select_experts
|
||||
from vllm_ascend.utils import dispose_tensor
|
||||
from vllm_ascend.utils import (dispose_tensor, npu_stream_switch,
|
||||
npu_wait_tensor)
|
||||
|
||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
|
||||
@@ -39,8 +39,7 @@ def apply_mlp(hidden_states: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
group_list: torch.Tensor,
|
||||
dynamic_scale: torch.Tensor = None,
|
||||
group_list_type: int = 1,
|
||||
**kwargs) -> torch.Tensor:
|
||||
group_list_type: int = 1) -> torch.Tensor:
|
||||
"""
|
||||
apply MLP: gate_up_proj -> swiglu -> down_proj
|
||||
|
||||
@@ -74,23 +73,6 @@ def apply_mlp(hidden_states: torch.Tensor,
|
||||
else:
|
||||
pertoken_scale = dynamic_scale
|
||||
|
||||
shared_experts = kwargs.get('shared_experts', None)
|
||||
if shared_experts:
|
||||
shared_gate_up = kwargs.get('shared_gate_up', None)
|
||||
shared_dynamic_scale = kwargs.get('shared_dynamic_scale', None)
|
||||
with tng.scope.npu_stream_switch('cv'):
|
||||
tng.scope.npu_wait_tensor(shared_gate_up, hidden_states)
|
||||
shared_x, shared_dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=shared_gate_up,
|
||||
weight_scale=shared_experts.gate_up_proj.weight_scale_fp32,
|
||||
activation_scale=shared_dynamic_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=None,
|
||||
activate_left=True,
|
||||
quant_mode=1)
|
||||
|
||||
# gmm1: gate_up_proj
|
||||
hidden_states = torch_npu.npu_grouped_matmul(
|
||||
x=[hidden_states],
|
||||
@@ -120,36 +102,24 @@ def apply_mlp(hidden_states: torch.Tensor,
|
||||
group_list=group_list,
|
||||
output_dtype=w2_scale.dtype)[0]
|
||||
|
||||
if shared_experts:
|
||||
with tng.scope.npu_stream_switch('cv'):
|
||||
tng.scope.npu_wait_tensor(shared_x, hidden_states)
|
||||
shared_output = torch_npu.npu_quant_matmul(
|
||||
shared_x,
|
||||
shared_experts.down_proj.weight,
|
||||
shared_experts.down_proj.weight_scale,
|
||||
pertoken_scale=shared_dynamic_scale,
|
||||
output_dtype=torch.bfloat16,
|
||||
)
|
||||
if shared_experts.down_proj.reduce_results and shared_experts.down_proj.tp_size > 1:
|
||||
shared_output = tensor_model_parallel_all_reduce(shared_output)
|
||||
if shared_experts:
|
||||
return hidden_states, shared_output
|
||||
return hidden_states
|
||||
|
||||
|
||||
def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
moe_all_to_all_group_name: str = "",
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
**kwargs) -> torch.Tensor:
|
||||
def fused_experts_with_mc2(
|
||||
hidden_states: torch.Tensor,
|
||||
w1: torch.Tensor,
|
||||
w2: torch.Tensor,
|
||||
w1_scale: torch.Tensor,
|
||||
w2_scale: torch.Tensor,
|
||||
topk_weights: torch.Tensor,
|
||||
topk_ids: torch.Tensor,
|
||||
top_k: int,
|
||||
expert_map: torch.Tensor = None,
|
||||
moe_all_to_all_group_name: str = "",
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
|
||||
if log2phy:
|
||||
topk_ids = log2phy[topk_ids]
|
||||
global_bs = 0
|
||||
@@ -188,31 +158,17 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
||||
}
|
||||
kwargs_mc2.update(stage1_kwargs)
|
||||
|
||||
shared_experts = kwargs.get('shared_experts', None)
|
||||
if shared_experts:
|
||||
shared_hidden_states = kwargs.get('shared_hidden_states', None)
|
||||
with tng.scope.npu_stream_switch('cv'):
|
||||
tng.scope.npu_wait_tensor(shared_hidden_states, hidden_states)
|
||||
shared_x, shared_dynamic_scale = torch_npu.npu_dynamic_quant(
|
||||
shared_hidden_states)
|
||||
shared_gate_up = torch_npu.npu_quant_matmul(
|
||||
shared_x,
|
||||
shared_experts.gate_up_proj.weight,
|
||||
shared_experts.gate_up_proj.weight_scale,
|
||||
output_dtype=torch.int32,
|
||||
)
|
||||
kwargs.update({
|
||||
"shared_gate_up": shared_gate_up,
|
||||
"shared_dynamic_scale": shared_dynamic_scale,
|
||||
})
|
||||
|
||||
output = torch_npu.npu_moe_distribute_dispatch(**kwargs_mc2)
|
||||
# comm_stream.wait_stream(torch.npu.current_stream())
|
||||
expand_x, dynamic_scale, expand_idx, expert_token_nums, ep_recv_counts = output[
|
||||
0:5]
|
||||
|
||||
if quant_mode == 0:
|
||||
dynamic_scale = None
|
||||
if shared_experts is not None:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(hidden_states, topk_weights)
|
||||
shared_gate_up, _ = shared_experts.gate_up_proj(hidden_states)
|
||||
npu_wait_tensor(shared_gate_up[0], expand_x)
|
||||
shared_act = shared_experts.act_fn(shared_gate_up)
|
||||
|
||||
# `expand_x` will be disposed in the `apply_mlp` function
|
||||
down_out_list = apply_mlp(expand_x,
|
||||
@@ -221,12 +177,7 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
||||
w2,
|
||||
w2_scale,
|
||||
expert_token_nums,
|
||||
dynamic_scale=dynamic_scale,
|
||||
**kwargs)
|
||||
|
||||
multi_stream = isinstance(down_out_list, tuple)
|
||||
if multi_stream:
|
||||
down_out_list, shared_output = down_out_list
|
||||
dynamic_scale=dynamic_scale)
|
||||
|
||||
# moeCombine
|
||||
kwargs_mc2 = {
|
||||
@@ -257,9 +208,13 @@ def fused_experts_with_mc2(hidden_states: torch.Tensor,
|
||||
|
||||
hidden_states = torch_npu.npu_moe_distribute_combine(**kwargs_mc2)
|
||||
|
||||
if multi_stream:
|
||||
if shared_experts is None:
|
||||
return hidden_states
|
||||
else:
|
||||
with npu_stream_switch("moe_secondary", 0):
|
||||
npu_wait_tensor(shared_act[0], down_out_list)
|
||||
shared_output, _ = shared_experts.down_proj(shared_act)
|
||||
return hidden_states, shared_output
|
||||
return hidden_states
|
||||
|
||||
|
||||
# currently expert parallelism implemented with all2all
|
||||
@@ -541,21 +496,33 @@ class AscendW8A8DynamicLinearMethod:
|
||||
@staticmethod
|
||||
def apply(
|
||||
layer: torch.nn.Module,
|
||||
x: torch.Tensor,
|
||||
x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
|
||||
bias: Optional[torch.Tensor] = None,
|
||||
tp_rank: Optional[int] = 0,
|
||||
) -> torch.Tensor:
|
||||
original_dtype = x.dtype
|
||||
# use ATB quantize
|
||||
quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
|
||||
return torch_npu.npu_quant_matmul(
|
||||
quant_out,
|
||||
config = getattr(layer, "_ascend_quant_config", {})
|
||||
if not isinstance(x, tuple):
|
||||
output_dtype = config.get("output_dtype", x.dtype)
|
||||
quantized_x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
|
||||
else:
|
||||
assert "output_dtype" in config.keys(), (
|
||||
f"DynamicLinearMethod needs explicitly specified `output_dtype`"
|
||||
f"for pre-quantized input, got config [{config}]")
|
||||
output_dtype = config["output_dtype"]
|
||||
quantized_x, dynamic_scale = x
|
||||
pertoken_scale = (dynamic_scale
|
||||
if config.get("pertoken_scale", True) else None)
|
||||
|
||||
output = torch_npu.npu_quant_matmul(
|
||||
quantized_x,
|
||||
layer.weight,
|
||||
layer.weight_scale,
|
||||
pertoken_scale=dynamic_scale,
|
||||
pertoken_scale=pertoken_scale,
|
||||
bias=bias,
|
||||
output_dtype=original_dtype,
|
||||
output_dtype=output_dtype,
|
||||
)
|
||||
return ((output, dynamic_scale)
|
||||
if config.get("return_scale", False) else output)
|
||||
|
||||
def process_weights_after_loading(self, layer):
|
||||
if self.transpose_weight:
|
||||
@@ -650,6 +617,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
enable_force_load_balance: bool = True,
|
||||
log2phy: torch.Tensor = None,
|
||||
global_redundant_expert_num: int = 0,
|
||||
shared_experts: Optional[Any] = None,
|
||||
**kwargs,
|
||||
) -> torch.Tensor:
|
||||
assert router_logits.shape[
|
||||
@@ -706,7 +674,7 @@ class AscendW8A8DynamicFusedMoEMethod:
|
||||
moe_all_to_all_group_name=self.moe_all_to_all_group_name,
|
||||
log2phy=log2phy,
|
||||
global_redundant_expert_num=global_redundant_expert_num,
|
||||
**kwargs)
|
||||
shared_experts=shared_experts)
|
||||
elif self.torchair_graph_enabled or self.ep_group.world_size == 1:
|
||||
return fused_experts(hidden_states=x,
|
||||
w1=layer.w13_weight,
|
||||
|
||||
Reference in New Issue
Block a user