feat: support data parallel for deepseek (#1012)
### What this PR does / why we need it?
Adds data-parallel (DP) support for DeepSeek models.
### Does this PR introduce _any_ user-facing change?
Yes. Data parallelism (`-dp`) can now be enabled for DeepSeek models.
### How was this patch tested?
```
export VLLM_ENABLE_MC2=0
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
nohup python -m vllm.entrypoints.openai.api_server \
--model=/path/to/DeepSeek-R1-W8A8 \
--quantization ascend \
--served-model-name auto \
--trust-remote-code \
--distributed-executor-backend=mp \
--port 8006 \
-tp=8 \
-dp=2 \
--max-num-seqs 24 \
--max-model-len 4096 \
--max-num-batched-tokens 4096 \
--block-size 128 \
-O 0 \
--no-enable-prefix-caching \
--additional-config \
'{"torchair_graph_batch_sizes":[24],"expert_tensor_parallel_size":16,"ascend_scheduler_config":{},"enable_graph_mode":true}' \
--gpu-memory-utilization 0.95 &> run.log &
disown
```
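Once the server is up, a quick sanity check is a completion request against the OpenAI-compatible endpoint. A minimal sketch: the port and served model name come from the command above; the prompt and token budget are arbitrary.

```
import json
import urllib.request

# Port 8006 and model name "auto" match the launch command above.
req = urllib.request.Request(
    "http://localhost:8006/v1/completions",
    data=json.dumps({
        "model": "auto",
        "prompt": "Hello",
        "max_tokens": 16,
    }).encode(),
    headers={"Content-Type": "application/json"},
)
with urllib.request.urlopen(req) as resp:
    print(json.load(resp)["choices"][0]["text"])
```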
Signed-off-by: boying <897013703@qq.com>
```
@@ -587,6 +587,12 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
         self.global_batch_size = vllm_config.scheduler_config.max_num_seqs
         self.local_batch_size = self.global_batch_size // self.ep_size
 
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
         try:
             device_group = ep_group.device_group
             # TODO: Try local_rank = ep_group.rank_in_group
```
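The hunk above caches the graph-mode flag once at init time. The flag originates from the `--additional-config` JSON passed on the command line. A standalone sketch of the same defaulting pattern, parsing the JSON from the test command directly (that the dict arrives this way is an assumption for illustration, not vLLM's actual plumbing):

```
import json

# The JSON string passed via --additional-config in the test command above.
additional_config = json.loads(
    '{"torchair_graph_batch_sizes": [24], "expert_tensor_parallel_size": 16,'
    ' "ascend_scheduler_config": {}, "enable_graph_mode": true}')

# Same lookup as the hunk: missing key or empty config -> graph mode off.
enable_graph_mode = False
if additional_config:
    enable_graph_mode = additional_config.get("enable_graph_mode", False)
assert enable_graph_mode is True
```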
```
@@ -664,7 +670,7 @@ class AscendUnquantizedFusedMoEMethod(UnquantizedFusedMoEMethod):
                 top_k=top_k,
                 expert_map=expert_map,
                 moe_all_to_all_group_name=self.moe_all_to_all_group_name)
-        elif get_ep_group().world_size == 1:
+        elif self.enable_graph_mode or get_ep_group().world_size == 1:
             return fused_experts(hidden_states=x,
                                  w1=layer.w13_weight,
                                  w2=layer.w2_weight,
```
```
@@ -750,26 +756,20 @@ class AscendFusedMoE(FusedMoE):
         self.expert_map = None
         self.activation = activation
 
         if self.ep_size > 1:
             # Create a tensor of size num_experts filled with -1
             self.local_num_experts, self.expert_map = determine_expert_map(
                 self.ep_size,
                 get_ep_group().rank_in_group, self.global_num_experts)
 
             self.moe_parallel_config.tp_rank = get_etp_group().rank_in_group
             self.moe_parallel_config.ep_rank = get_ep_group().rank_in_group
 
         else:
             # Adjust TP size for DP attention
             # haven't test its functionality yet, may remove in the future
-            self.enable_graph_mode = False
-            additional_config = get_current_vllm_config().additional_config
-            if additional_config:
-                self.enable_graph_mode = additional_config.get(
-                    "enable_graph_mode", False)
-
             self.moe_parallel_config.tp_rank = self.tp_size * self.dp_rank
             self.moe_parallel_config.ep_rank = 0
             self.moe_parallel_config.tp_size = self.tp_size * self.dp_size
             self.moe_parallel_config.ep_size = 1
 
             self.local_num_experts, self.expert_map = (self.global_num_experts,
                                                        None)
         if self.scoring_func != "softmax" and not self.use_grouped_topk:
             raise ValueError("Only softmax scoring function is supported for "
                              "non-grouped topk.")
```
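For readers unfamiliar with `determine_expert_map`: per the comment in the hunk, it returns a tensor of size `global_num_experts` filled with -1, with this EP rank's experts renumbered to local ids. A toy sketch that captures the idea, assuming an even split (the real vLLM helper also handles remainders):

```
import torch

def toy_expert_map(ep_size: int, ep_rank: int, global_num_experts: int):
    # Tensor of size global_num_experts filled with -1; this rank's
    # contiguous slice is renumbered 0..local_num_experts-1.
    local_num_experts = global_num_experts // ep_size
    expert_map = torch.full((global_num_experts, ), -1, dtype=torch.int32)
    start = ep_rank * local_num_experts
    expert_map[start:start + local_num_experts] = torch.arange(
        local_num_experts, dtype=torch.int32)
    return local_num_experts, expert_map

# With 8 experts and ep_size=2, rank 1 owns global experts 4..7.
n_local, emap = toy_expert_map(ep_size=2, ep_rank=1, global_num_experts=8)
assert n_local == 4
assert emap.tolist() == [-1, -1, -1, -1, 0, 1, 2, 3]
```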
```
@@ -807,8 +807,15 @@ class AscendFusedMoE(FusedMoE):
                 in ("GPTQMarlinMoEMethod", "CompressedTensorsWNA16MoEMethod")):
             moe_quant_params["intermediate_size_full"] = intermediate_size
 
+        self.ep_group = get_ep_group()
         self.quant_method.create_weights(layer=self, **moe_quant_params)
 
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
     def forward(self,
                 hidden_states: torch.Tensor,
                 router_logits: torch.Tensor,
```
```
@@ -822,11 +829,28 @@ class AscendFusedMoE(FusedMoE):
         else:
             real_top_k = self.top_k
 
-        if VLLM_ENABLE_MC2 and not is_prefill:
-            ...
+        #              MC2    ag/rs    broadcast/all_reduce
+        # prefill_req  x      x        √
+        # decode_req   √      x        √
+        # graph_mode   √      √        x
+        if self.dp_size > 1:
+            if VLLM_ENABLE_MC2 and not is_prefill:
+                ...
+            elif self.enable_graph_mode:
+                if USING_LCCL_COM:  # type: ignore
+                    hidden_states = get_dp_group().all_gather(
+                        hidden_states, 0, False)
+                    router_logits = get_dp_group().all_gather(
+                        router_logits, 0, False)
+                elif self.enable_graph_mode and not is_prefill:
+                    hidden_states = get_dp_group().all_gather(hidden_states, 0)
+                    router_logits = get_dp_group().all_gather(router_logits, 0)
+            else:
+                hidden_states, router_logits = get_ep_group().dispatch(
+                    hidden_states, router_logits)
 
         # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
+        hidden_states = self.quant_method.apply(
             layer=self,
             x=hidden_states,
             router_logits=router_logits,
```
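The dispatch branch above exists because with `dp_size > 1` each DP rank holds only its own slice of the batch, while the experts must see all of it; the all-gather along dim 0 stitches the slices together before `quant_method.apply`. A single-process simulation of the shape bookkeeping (ranks, token counts, and hidden size are illustrative, and `torch.cat` stands in for the collective):

```
import torch

dp_size, tokens_per_rank, hidden = 2, 4, 8

# Each DP rank's local activations entering the MoE layer.
local_shards = [torch.randn(tokens_per_rank, hidden) for _ in range(dp_size)]

# all_gather over dim 0: every rank now sees the full global batch, so the
# experts run over dp_size * tokens_per_rank tokens.
gathered = torch.cat(local_shards, dim=0)
assert gathered.shape == (dp_size * tokens_per_rank, hidden)
```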
```
@@ -843,11 +867,26 @@ class AscendFusedMoE(FusedMoE):
             is_prefill=is_prefill,
             enable_force_load_balance=enable_force_load_balance)
 
-        if VLLM_ENABLE_MC2 and not is_prefill:
-            ...
+        if self.dp_size > 1:
+            if VLLM_ENABLE_MC2 and not is_prefill:
+                ...
+            elif self.enable_graph_mode:
+                if USING_LCCL_COM:  # type: ignore
+                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        hidden_states,
+                        "sum",
+                        scatter_dim=0,
+                        group=get_dp_group().device_group)
+                elif self.enable_graph_mode and not is_prefill:
+                    hidden_states = dist._functional_collectives.reduce_scatter_tensor(
+                        hidden_states,
+                        "sum",
+                        scatter_dim=0,
+                        group=get_dp_group().device_group)
+            else:
+                hidden_states = get_ep_group().combine(hidden_states)
 
         if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
-            final_hidden_states = tensor_model_parallel_all_reduce(
-                final_hidden_states)
+            hidden_states = tensor_model_parallel_all_reduce(hidden_states)
 
-        return final_hidden_states
+        return hidden_states
```
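And the inverse on the way out: `reduce_scatter_tensor(..., "sum", scatter_dim=0, ...)` sums the per-rank expert outputs and hands each rank back just its own token slice, undoing the gather above. Simulated single-process with the same illustrative sizes, where `stack(...).sum()` plus `chunk` stands in for the collective:

```
import torch

dp_size, tokens_per_rank, hidden = 2, 4, 8

# Each rank's partial MoE output over the full gathered batch.
partials = [torch.randn(dp_size * tokens_per_rank, hidden)
            for _ in range(dp_size)]

# reduce_scatter with "sum" over dim 0: sum across ranks, then split so
# rank i keeps only the slice matching its original tokens.
summed = torch.stack(partials).sum(dim=0)
shards = summed.chunk(dp_size, dim=0)
assert all(s.shape == (tokens_per_rank, hidden) for s in shards)
```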