Support multistream of shared experts in FusedMoE (#997)
Contains on #1111 for completeness. <!-- Thanks for sending a pull request! BEFORE SUBMITTING, PLEASE READ https://docs.vllm.ai/en/latest/contributing/overview.html --> ### What this PR does / why we need it? Implement multi-stream parallelism for MoE layers with shared experts, where computation of shared experts will be overlapped with expert token dispatch and combine. Also, when multi-stream is enabled, weights of shared experts will be force to replicate across all cards, regardless of any tensor parallelism configurations, to avoid AllReduce operations. With the expected overlaping being: ``` | shared gate_up | shared act | | shared down | | dispatch | routed gate_up, act, down | combine | ``` <!-- - Please clarify what changes you are proposing. The purpose of this section is to outline the changes and how this PR fixes the issue. If possible, please consider writing useful notes for better and faster reviews in your PR. - Please clarify why the changes are needed. For instance, the use case and bug description. - Fixes # --> ### Does this PR introduce _any_ user-facing change? No. <!-- Note that it means *any* user-facing change including all aspects such as API, interface or other behavior changes. Documentation-only updates are not considered user-facing changes. --> ### How was this patch tested? Tested on 1x16 910 node, with tailored 2 layer DSKv2. <!-- CI passed with new added/existing test. If it was tested in a way different from regular unit tests, please clarify how you tested step by step, ideally copy and paste-able, so that other reviewers can test and check, and descendants can verify in the future. If tests were not added, please describe why they were not added and/or why it was difficult to add. --> --------- Signed-off-by: sdmyzlp <lrwei2@petalmail.com>
This commit is contained in:
@@ -29,7 +29,7 @@ from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch_npu
|
||||
import torch_npu # noqa: F401
|
||||
import vllm.envs as envs
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig
|
||||
@@ -40,13 +40,10 @@ from vllm.distributed import (get_pp_group,
|
||||
get_tp_group, tensor_model_parallel_all_reduce)
|
||||
from vllm.distributed.parallel_state import get_dp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
from vllm.model_executor.layers.activation import SiluAndMul
|
||||
from vllm.model_executor.layers.layernorm import RMSNorm
|
||||
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
|
||||
MergedColumnParallelLinear,
|
||||
ReplicatedLinear,
|
||||
RowParallelLinear,
|
||||
UnquantizedLinearMethod)
|
||||
RowParallelLinear)
|
||||
from vllm.model_executor.layers.logits_processor import LogitsProcessor
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.layers.rotary_embedding import get_rope
|
||||
@@ -67,6 +64,7 @@ from vllm.sequence import IntermediateTensors
|
||||
|
||||
import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.models.deepseek_v2 import CustomDeepseekV2MLP
|
||||
from vllm_ascend.multistream.base import MSEventKey
|
||||
from vllm_ascend.multistream.context import (
|
||||
advance_step_multistream_layer_context, get_multistream_comm_context,
|
||||
@@ -78,117 +76,17 @@ from vllm_ascend.multistream.metadata import (MultiStreamConfig,
|
||||
make_multistream_metadata_ds)
|
||||
from vllm_ascend.multistream.ms_split import compute_split_seq_index
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
||||
from vllm_ascend.utils import dispose_tensor
|
||||
|
||||
VLLM_ASCEND_ENABLE_DBO: bool = envs_ascend.VLLM_ASCEND_ENABLE_DBO
|
||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
|
||||
|
||||
class CustomDeepseekDBOMLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int,
|
||||
intermediate_size: int,
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
# NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
|
||||
self.is_dynamic_quant = not isinstance(
|
||||
self.gate_up_proj.quant_method,
|
||||
UnquantizedLinearMethod) and isinstance(
|
||||
self.gate_up_proj.quant_method.quant_method,
|
||||
AscendW8A8DynamicLinearMethod)
|
||||
|
||||
def forward(self, x):
|
||||
if self.is_dynamic_quant:
|
||||
x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
|
||||
x = torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
self.gate_up_proj.weight,
|
||||
self.gate_up_proj.weight_scale,
|
||||
output_dtype=torch.int32,
|
||||
)
|
||||
x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=x,
|
||||
weight_scale=self.gate_up_proj.weight_scale_fp32,
|
||||
activation_scale=dynamic_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=None,
|
||||
activate_left=True,
|
||||
quant_mode=1)
|
||||
x = torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
self.down_proj.weight,
|
||||
self.down_proj.weight_scale,
|
||||
pertoken_scale=dynamic_scale,
|
||||
output_dtype=torch.bfloat16,
|
||||
)
|
||||
if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
|
||||
x = tensor_model_parallel_all_reduce(x)
|
||||
return x
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
return x
|
||||
class CustomDeepseekDBOMLP(CustomDeepseekV2MLP):
|
||||
|
||||
def _forward_ms_mlp(self, x):
|
||||
current_ms_metadata = get_multistream_comm_context()
|
||||
assert current_ms_metadata is not None
|
||||
if self.is_dynamic_quant:
|
||||
x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
|
||||
x = torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
self.gate_up_proj.weight,
|
||||
self.gate_up_proj.weight_scale,
|
||||
output_dtype=torch.int32,
|
||||
)
|
||||
x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=x,
|
||||
weight_scale=self.gate_up_proj.weight_scale_fp32,
|
||||
activation_scale=dynamic_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=None,
|
||||
activate_left=True,
|
||||
quant_mode=1)
|
||||
x = torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
self.down_proj.weight,
|
||||
self.down_proj.weight_scale,
|
||||
pertoken_scale=dynamic_scale,
|
||||
output_dtype=torch.bfloat16,
|
||||
)
|
||||
if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
|
||||
current_ms_metadata.before_comm_event.record()
|
||||
with torch.npu.stream(current_ms_metadata.comm_stream):
|
||||
current_ms_metadata.before_comm_event.wait()
|
||||
x = tensor_model_parallel_all_reduce(x)
|
||||
current_ms_metadata.after_comm_event.record()
|
||||
return x
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
current_ms_metadata.before_comm_event.record()
|
||||
|
||||
@@ -25,7 +25,7 @@
|
||||
# # vllm-project/vllm/vllm/model_executor/models/deepseek_v2.py
|
||||
# """Inference-only DeepseekV2/DeepseekV3 model."""
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
@@ -69,12 +69,73 @@ import vllm_ascend.envs as envs_ascend
|
||||
from vllm_ascend.ascend_config import get_ascend_config
|
||||
from vllm_ascend.distributed.parallel_state import get_ep_group
|
||||
from vllm_ascend.ops.fused_moe import AscendFusedMoE
|
||||
from vllm_ascend.quantization.quant_config import AscendLinearMethod
|
||||
from vllm_ascend.quantization.w8a8_dynamic import AscendW8A8DynamicLinearMethod
|
||||
from vllm_ascend.utils import dispose_tensor
|
||||
|
||||
VLLM_ENABLE_MC2: bool = envs_ascend.VLLM_ENABLE_MC2
|
||||
|
||||
|
||||
class CustomDeepseekV2SiluAndMul(SiluAndMul):
|
||||
|
||||
def __init__(self,
|
||||
*,
|
||||
weight_scale: Optional[Callable[[], torch.Tensor]] = None):
|
||||
super().__init__()
|
||||
self.weight_scale = weight_scale
|
||||
|
||||
def forward_oot(self, x: Union[torch.Tensor, Tuple[torch.Tensor,
|
||||
torch.Tensor]]):
|
||||
if isinstance(x, tuple):
|
||||
assert self.weight_scale is not None
|
||||
# For AscendW8A8DynamicLinearMethod:
|
||||
# a dynamic scale is passed along with the quantized value.
|
||||
quantized_x, dynamic_scale = x
|
||||
return torch_npu.npu_dequant_swiglu_quant(
|
||||
x=quantized_x,
|
||||
weight_scale=self.weight_scale(),
|
||||
activation_scale=dynamic_scale,
|
||||
activate_left=True,
|
||||
quant_mode=1)
|
||||
else:
|
||||
return super().forward_oot(x)
|
||||
|
||||
|
||||
class CustomDeepseekV2MergedReplicatedLinear(ReplicatedLinear):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_sizes: list[int],
|
||||
bias: bool = True,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
):
|
||||
self.output_sizes = output_sizes
|
||||
super().__init__(input_size,
|
||||
sum(output_sizes),
|
||||
bias=bias,
|
||||
quant_config=quant_config,
|
||||
prefix=prefix)
|
||||
|
||||
def weight_loader(self, param: torch.nn.Parameter,
|
||||
loaded_weight: torch.Tensor, loaded_shard_id: int):
|
||||
# With no support for GGUF format yet.
|
||||
assert not getattr(param, "is_gguf_weight", False)
|
||||
assert not getattr(param, "is_gguf_weight_type", False)
|
||||
|
||||
assert loaded_shard_id < len(self.output_sizes)
|
||||
shard_offset = sum(self.output_sizes[:loaded_shard_id])
|
||||
shard_size = self.output_sizes[loaded_shard_id]
|
||||
shard = param.data.narrow(param.output_dim, shard_offset, shard_size)
|
||||
|
||||
assert shard.size() == loaded_weight.size(), (
|
||||
f"Tried to load weights of size {loaded_weight.size()}"
|
||||
f"to a parameter shard of id {loaded_shard_id} size {shard.size()}"
|
||||
)
|
||||
shard.copy_(loaded_weight)
|
||||
|
||||
|
||||
class CustomDeepseekV2MLP(nn.Module):
|
||||
|
||||
def __init__(
|
||||
@@ -84,61 +145,68 @@ class CustomDeepseekV2MLP(nn.Module):
|
||||
hidden_act: str,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
reduce_results: bool = True,
|
||||
force_replicate: bool = False,
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
if not force_replicate:
|
||||
self.gate_up_proj = MergedColumnParallelLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = RowParallelLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
reduce_results=reduce_results,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
else:
|
||||
self.gate_up_proj = CustomDeepseekV2MergedReplicatedLinear(
|
||||
hidden_size, [intermediate_size] * 2,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.gate_up_proj")
|
||||
self.down_proj = ReplicatedLinear(intermediate_size,
|
||||
hidden_size,
|
||||
bias=False,
|
||||
quant_config=quant_config,
|
||||
prefix=f"{prefix}.down_proj")
|
||||
if hidden_act != "silu":
|
||||
raise ValueError(f"Unsupported activation: {hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
self.act_fn = SiluAndMul()
|
||||
|
||||
# NOTE: `torch_npu.npu_dequant_swiglu_quant` can only be enabled in dynamic quant
|
||||
self.is_dynamic_quant = not isinstance(
|
||||
self.gate_up_proj.quant_method,
|
||||
UnquantizedLinearMethod) and isinstance(
|
||||
self.gate_up_proj.quant_method.quant_method,
|
||||
AscendW8A8DynamicLinearMethod)
|
||||
quant_method = self.gate_up_proj.quant_method
|
||||
if isinstance(quant_method, UnquantizedLinearMethod):
|
||||
self.act_fn = CustomDeepseekV2SiluAndMul()
|
||||
elif (isinstance(quant_method, AscendLinearMethod) and isinstance(
|
||||
quant_method.quant_method, AscendW8A8DynamicLinearMethod)):
|
||||
# TODO(sdmyzlp): Currently preserved as before:
|
||||
# 1. The only quantization supported for silu is W8A8Dynamic
|
||||
# 2. Output dtype of gate_up/down is fixed to be int32/bfloat16
|
||||
#
|
||||
# Maybe one can implement a better and more general configuration
|
||||
# scheme, e.g. by somehow passing around the tweaked `quant_config`
|
||||
self.act_fn = CustomDeepseekV2SiluAndMul(
|
||||
# Use lazy binding, for `weight_scale_fp32` is accessible
|
||||
# only after `process_weights_after_loading`.
|
||||
weight_scale=lambda: self.gate_up_proj.weight_scale_fp32)
|
||||
# To be consumed by AscendW8A8DynamicLinearMethod.apply()
|
||||
self.gate_up_proj._ascend_quant_config = {
|
||||
"output_dtype": torch.int32,
|
||||
"pertoken_scale": False,
|
||||
"return_scale": True,
|
||||
}
|
||||
self.down_proj._ascend_quant_config = {
|
||||
"output_dtype": torch.bfloat16,
|
||||
"pertoken_scale": True,
|
||||
"return_scale": False,
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"Quantization with [{type(quant_method)}] is NOT supported")
|
||||
|
||||
def forward(self, x):
|
||||
if self.is_dynamic_quant:
|
||||
x, dynamic_scale = torch_npu.npu_dynamic_quant(x)
|
||||
x = torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
self.gate_up_proj.weight,
|
||||
self.gate_up_proj.weight_scale,
|
||||
output_dtype=torch.int32,
|
||||
)
|
||||
x, dynamic_scale = torch_npu.npu_dequant_swiglu_quant(
|
||||
x=x,
|
||||
weight_scale=self.gate_up_proj.weight_scale_fp32,
|
||||
activation_scale=dynamic_scale,
|
||||
bias=None,
|
||||
quant_scale=None,
|
||||
quant_offset=None,
|
||||
group_index=None,
|
||||
activate_left=True,
|
||||
quant_mode=1)
|
||||
x = torch_npu.npu_quant_matmul(
|
||||
x,
|
||||
self.down_proj.weight,
|
||||
self.down_proj.weight_scale,
|
||||
pertoken_scale=dynamic_scale,
|
||||
output_dtype=torch.bfloat16,
|
||||
)
|
||||
if self.down_proj.reduce_results and self.down_proj.tp_size > 1:
|
||||
x = tensor_model_parallel_all_reduce(x)
|
||||
return x
|
||||
gate_up, _ = self.gate_up_proj(x)
|
||||
x = self.act_fn(gate_up)
|
||||
x, _ = self.down_proj(x)
|
||||
@@ -169,6 +237,12 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
raise ValueError(f"Unsupported activation: {config.hidden_act}. "
|
||||
"Only silu is supported for now.")
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
# NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
|
||||
self.enable_multistream_moe = \
|
||||
ascend_config.torchair_graph_config.enable_multistream_moe and VLLM_ENABLE_MC2
|
||||
|
||||
self.gate = ReplicatedLinear(config.hidden_size,
|
||||
config.n_routed_experts,
|
||||
bias=False,
|
||||
@@ -204,8 +278,11 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
hidden_act=config.hidden_act,
|
||||
quant_config=quant_config,
|
||||
reduce_results=True,
|
||||
force_replicate=self.enable_multistream_moe,
|
||||
prefix=f"{prefix}.shared_experts",
|
||||
)
|
||||
else:
|
||||
self.shared_experts = None # type: ignore
|
||||
CustomDeepseekV2MoE.top_k = config.num_experts_per_tok
|
||||
|
||||
self.dp_size = get_dp_group().world_size
|
||||
@@ -216,12 +293,6 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
|
||||
self.params_dtype = torch.get_default_dtype()
|
||||
|
||||
ascend_config = get_ascend_config()
|
||||
self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
|
||||
# NOTE: multistream only effective when `VLLM_ENABLE_MC2` is on
|
||||
self.enable_multistream_shared_expert = \
|
||||
ascend_config.torchair_graph_config.enable_multistream_shared_expert and VLLM_ENABLE_MC2
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -240,12 +311,10 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
enable_force_load_balance = False
|
||||
if hasattr(attn_metadata, 'with_prefill_across_dp'):
|
||||
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
||||
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
|
||||
multistream = self.enable_multistream_shared_expert and not is_prefill
|
||||
|
||||
old_hidden_states = hidden_states.clone()
|
||||
old_hidden_states = hidden_states
|
||||
use_separated_shared_experts = (self.shared_experts is not None
|
||||
and not self.enable_multistream_moe)
|
||||
|
||||
if self.tp_size > 1:
|
||||
if (VLLM_ENABLE_MC2
|
||||
@@ -262,25 +331,22 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
kwargs = {}
|
||||
if multistream:
|
||||
kwargs.update({
|
||||
"shared_experts": self.shared_experts,
|
||||
"shared_hidden_states": old_hidden_states
|
||||
})
|
||||
|
||||
hidden_states = self.experts(
|
||||
experts_hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=CustomDeepseekV2MoE.top_k,
|
||||
enable_force_load_balance=enable_force_load_balance,
|
||||
**kwargs)
|
||||
shared_experts=(self.shared_experts
|
||||
if not use_separated_shared_experts else None),
|
||||
)
|
||||
|
||||
if multistream:
|
||||
hidden_states, shared_output = hidden_states
|
||||
|
||||
hidden_states = hidden_states * self.routed_scaling_factor
|
||||
if not isinstance(experts_hidden_states, tuple):
|
||||
hidden_states = experts_hidden_states * self.routed_scaling_factor
|
||||
else:
|
||||
hidden_states = (
|
||||
experts_hidden_states[0] * self.routed_scaling_factor +
|
||||
experts_hidden_states[1])
|
||||
|
||||
if self.tp_size > 1:
|
||||
if (VLLM_ENABLE_MC2
|
||||
@@ -294,12 +360,9 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
||||
|
||||
if self.n_shared_experts is not None:
|
||||
if not multistream:
|
||||
shared_output = self.shared_experts(old_hidden_states)
|
||||
|
||||
if shared_output is not None:
|
||||
hidden_states = hidden_states + shared_output
|
||||
if use_separated_shared_experts:
|
||||
hidden_states = hidden_states + self.shared_experts(
|
||||
old_hidden_states)
|
||||
|
||||
return hidden_states.view(num_tokens, hidden_size)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user