[main][prefill optimization] Optimize parallel strategies to reduce communication overhead (#2198)
### What this PR does / why we need it?
1. Shared expert sharding strategy update: switched from TP-aligned to
pure DP for the shared experts, enabling more efficient execution.
2. o_proj AllReduce → ReduceScatter: reduced communication overhead by
using ReduceScatter, made possible by the pure-DP sharding (see the
sketch below this list).
3. AllGather postponed: delayed until after the QKV down-projection to
reduce synchronization overhead during prefill.
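
For intuition on change 2: a reduce-scatter moves roughly `1/tp_size` of the bytes an all-reduce moves, because each rank keeps only its shard of the summed token rows. Below is a minimal standalone sketch, not code from this PR; `complete_row_parallel` and its flag are illustrative names, and an initialized `torch.distributed` process group is assumed:

```python
import torch
import torch.distributed as dist


def complete_row_parallel(partial: torch.Tensor,
                          use_reduce_scatter: bool) -> torch.Tensor:
    """Sum per-rank partial outputs of a row-parallel linear layer."""
    tp_size = dist.get_world_size()
    if use_reduce_scatter:
        # Each rank keeps only its 1/tp_size shard of the summed token
        # rows, moving ~1/tp_size of the bytes an all-reduce would. The
        # row count must divide evenly by tp_size, which is why the PR
        # pads output_parallel before the force-scatter path.
        shard = torch.empty(partial.shape[0] // tp_size,
                            partial.shape[1],
                            dtype=partial.dtype,
                            device=partial.device)
        dist.reduce_scatter_tensor(shard, partial)
        return shard
    # Baseline: every rank materializes the full summed tensor.
    dist.all_reduce(partial)
    return partial
```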
### How was this patch tested?
Added a UT case in `tests/ut/attention/test_mla_v1.py`.
#### How to run
Enable the feature with `--additional_config='{"enable_shared_expert_dp": true}'`.
##### a. Eager mode
Example:
```bash
python -m vllm.entrypoints.openai.api_server --model=/model_path \
    --trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002 \
    --max-model-len 5120 --max-num-batched-tokens 16384 --enforce-eager \
    --disable-log-requests \
    --additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":true,"chunked_prefill_for_mla":true}'
```
##### b. Graph mode
Example:
```bash
python -m vllm.entrypoints.openai.api_server --model=/model_path \
    --trust-remote-code -tp 8 -dp 2 --enable_expert_parallel --port 8002 \
    --max-model-len 5120 --max-num-batched-tokens 16384 \
    --disable-log-requests \
    --additional_config='{"ascend_scheduler_config":{"enabled":true},"enable_shared_expert_dp":true,"chunked_prefill_for_mla":true,"torchair_graph_config":{"enabled":true}}'
```
- vLLM version: v0.10.0
- vLLM main: 9edd1db02b
---------
Signed-off-by: Wang Kunpeng <1289706727@qq.com>
Signed-off-by: SlightwindSec <slightwindsec@gmail.com>
Co-authored-by: SlightwindSec <slightwindsec@gmail.com>
```diff
@@ -141,7 +141,8 @@ class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
     def forward(
         self,
         input_,
-        is_prefill=True
+        is_prefill=True,
+        is_force_scatter=False
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
         if self.input_is_parallel:
             input_parallel = input_
```
```diff
@@ -160,7 +161,13 @@ class CustomDeepseekV2RowParallelLinearReplaceAllreduce(RowParallelLinear):
                                                   input_parallel,
                                                   bias=bias_)
         if self.reduce_results and self.tp_size > 1:
-            if not is_prefill and output_parallel.shape[0] % self.tp_size == 0:
+            num_tokens = output_parallel.shape[0]
+            if is_force_scatter and num_tokens % self.tp_size:
+                output_parallel = nn.functional.pad(
+                    output_parallel, (0, 0, 0, -num_tokens % self.tp_size))
+            if is_force_scatter or (not is_prefill
+                                    and output_parallel.shape[0] % self.tp_size
+                                    == 0):
                 output = tensor_model_parallel_reduce_scatter(output_parallel,
                                                               dim=0)
             else:
```
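A note on the pad arithmetic in this hunk: in Python, `-num_tokens % tp_size` equals the number of rows needed to reach the next multiple of `tp_size`, so `nn.functional.pad` appends exactly that many zero rows before the reduce-scatter. A standalone illustration (not PR code):

```python
import torch
import torch.nn.functional as F

x = torch.randn(10, 8)  # 10 token rows
tp_size = 4
# (-10) % 4 == 2 in Python, so two zero rows are appended.
padded = F.pad(x, (0, 0, 0, -x.shape[0] % tp_size))
assert padded.shape == (12, 8)  # 12 rows split evenly across 4 ranks
```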
```diff
@@ -180,7 +187,8 @@ class CustomDeepseekV2RowParallelLinear(RowParallelLinear):
     def forward(
         self,
         input_,
-        is_prefill=True
+        is_prefill=True,
+        is_force_scatter=False
     ) -> Union[torch.Tensor, tuple[torch.Tensor, Optional[nn.Parameter]]]:
         if self.input_is_parallel:
             input_parallel = input_
```
```diff
@@ -347,13 +355,15 @@ class CustomDeepseekV2MoE(nn.Module):
             reduce_results = not self.all_reduce_merge
             intermediate_size = (config.moe_intermediate_size *
                                  config.n_shared_experts)
+            enable_shared_expert_dp = ascend_config.enable_shared_expert_dp
             self.shared_experts = CustomDeepseekV2MLP(
                 hidden_size=config.hidden_size,
                 intermediate_size=intermediate_size,
                 hidden_act=config.hidden_act,
                 quant_config=quant_config,
                 reduce_results=reduce_results,
-                force_replicate=self.enable_multistream_moe,
+                force_replicate=self.enable_multistream_moe
+                or enable_shared_expert_dp,
                 prefix=f"{prefix}.shared_experts",
             )
         else:
```
```diff
@@ -447,9 +457,11 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
         self.kv_lora_rank = kv_lora_rank

         self.num_heads = num_heads
-        tp_size = get_tensor_model_parallel_world_size()
-        assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+        assert num_heads % self.tp_size == 0
+        self.num_local_heads = num_heads // self.tp_size
+        self.layers = config.num_hidden_layers
+        self.first_k_dense_replace = config.first_k_dense_replace

         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
```
```diff
@@ -462,6 +474,7 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
         self.torchair_graph_enabled = ascend_config.torchair_graph_config.enabled
         self.enable_multistream_mla = \
             ascend_config.torchair_graph_config.enable_multistream_mla
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

         if self.q_lora_rank is not None:
             self.q_a_proj = ReplicatedLinear(self.hidden_size,
```
```diff
@@ -501,8 +514,9 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
                                        prefix=f"{prefix}.kv_b_proj")
         if (config.n_routed_experts is not None
                 and self.debug_layer_idx >= config.first_k_dense_replace
-                and self.debug_layer_idx % config.moe_layer_freq == 0 and
-                ascend_config.torchair_graph_config.enable_multistream_moe):
+                and self.debug_layer_idx % config.moe_layer_freq == 0
+                and (ascend_config.torchair_graph_config.enable_multistream_moe
+                     or self.enable_shared_expert_dp)):
             self.o_proj = CustomDeepseekV2RowParallelLinearReplaceAllreduce(
                 self.num_heads * self.v_head_dim,
                 self.hidden_size,
```
```diff
@@ -596,13 +610,27 @@ class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
             output = output.view(-1, output_shape[-1])
             return output
         else:
-            kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
+            kv_no_split = self.kv_a_proj_with_mqa(hidden_states)[0]
+            if self.enable_shared_expert_dp and self.debug_layer_idx > self.first_k_dense_replace and self.debug_layer_idx < self.layers:
+                hidden_states_or_q_c = get_tp_group().all_gather(
+                    hidden_states_or_q_c, 0)
+                kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
+
+            kv_c, k_pe = kv_no_split.split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
             kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
+            if not self.enable_shared_expert_dp or self.debug_layer_idx < self.first_k_dense_replace:
+                output_shape = hidden_states.shape
+            else:
+                num_tokens = hidden_states_or_q_c.shape[0]
+                rows = num_tokens // self.tp_size
+                if num_tokens % self.tp_size:
+                    rows += 1
+                output_shape = (rows, hidden_states.shape[1])
             return self.mla_attn(hidden_states_or_q_c,
                                  kv_c_normed,
                                  k_pe,
-                                 output_shape=hidden_states.shape)
+                                 output_shape=output_shape)


 class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
```
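This hunk implements change 3: the TP all-gather is postponed until after `kv_a_proj_with_mqa` has down-projected the hidden states, so a much narrower tensor crosses the interconnect, and `output_shape` is sized by ceil-dividing the gathered token count by `tp_size` to match the padded reduce-scatter. A standalone sketch of that row computation (illustrative, not PR code):

```python
def scattered_rows(num_tokens: int, tp_size: int) -> int:
    """Rows each rank holds after reduce-scatter over padded tokens."""
    rows = num_tokens // tp_size
    if num_tokens % tp_size:  # tokens were padded to the next multiple
        rows += 1
    return rows


assert scattered_rows(10, 4) == 3  # 10 tokens padded to 12, 3 per rank
assert scattered_rows(12, 4) == 3  # already divisible, no padding
```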
```diff
@@ -677,6 +705,8 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
                                                 eps=config.rms_norm_eps)
         self.routed_scaling_factor = config.routed_scaling_factor
+        self.first_k_dense_replace = config.first_k_dense_replace
         self.tp_group = get_tp_group().device_group
+        self.enable_shared_expert_dp = ascend_config.enable_shared_expert_dp

     def forward(
         self,
```
```diff
@@ -731,6 +761,18 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
             # first layer.
             residual *= 1. / self.routed_scaling_factor

+        tp_size = get_tensor_model_parallel_world_size()
+        if self.enable_shared_expert_dp and (
+                self.layer_idx == self.first_k_dense_replace
+                or self.layer_idx == self.layers) and tp_size > 1:
+            num_tokens, _ = residual.shape
+            if num_tokens % tp_size:
+                residual = nn.functional.pad(residual,
+                                             (0, 0, 0, -num_tokens % tp_size))
+            chunk_residual = torch.tensor_split(residual, tp_size, dim=0)
+            tp_rank = get_tensor_model_parallel_rank()
+            residual = chunk_residual[tp_rank]
+
         # Fully Connected
         hidden_states, residual = self.post_attention_layernorm(
             hidden_states, residual)
```
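This hunk shards the residual at the TP boundary layers so each rank carries only its contiguous chunk of token rows, matching the scattered attention output. A minimal standalone illustration (not PR code; the rank and size values are stand-ins for the real group queries):

```python
import torch
import torch.nn.functional as F

residual = torch.randn(10, 8)
tp_size, tp_rank = 4, 1  # stand-ins for the distributed group queries
# Pad to a multiple of tp_size, then keep this rank's chunk of rows.
residual = F.pad(residual, (0, 0, 0, -residual.shape[0] % tp_size))
chunks = torch.tensor_split(residual, tp_size, dim=0)
assert chunks[tp_rank].shape == (3, 8)  # this rank's 3 of 12 rows
```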
```diff
@@ -756,6 +798,22 @@ class CustomDeepseekV2DecoderLayer(DeepseekV2DecoderLayer):
                                                                  dim=0)
             residual = tensor_model_parallel_all_gather(residual, dim=0)

+        # for last layer of main model and mtp layer.
+        if self.enable_shared_expert_dp and self.layer_idx >= (
+                self.layers - 1) and tp_size > 1:
+            hidden_states = get_tp_group().all_gather(hidden_states, 0)
+            residual = get_tp_group().all_gather(residual, 0)
+
+            attn_metadata = get_forward_context().attn_metadata
+            if attn_metadata is not None:
+                num_tokens = attn_metadata.num_actual_tokens
+            else:
+                num_tokens = hidden_states.shape[0]
+
+            if num_tokens < hidden_states.shape[0]:
+                hidden_states = hidden_states[:num_tokens]
+                residual = residual[:num_tokens]
+
         return hidden_states, residual
```
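After the last-layer all-gather, the pad rows introduced at the TP boundary reappear, so this hunk trims outputs back to the actual token count. Sketched standalone with illustrative values (not PR code):

```python
import torch

hidden_states = torch.randn(12, 8)  # gathered output, includes 2 pad rows
num_tokens = 10                     # actual tokens this step
if num_tokens < hidden_states.shape[0]:
    hidden_states = hidden_states[:num_tokens]  # drop the pad rows
assert hidden_states.shape == (10, 8)
```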