feat: support data parallel for deepseek (#1012)

### What this PR does / why we need it?
This PR adds data-parallel (DP) inference support for DeepSeek models in vllm-ascend.

### Does this PR introduce _any_ user-facing change?
Yes. Data parallelism can now be enabled for DeepSeek models, e.g. via `-dp=2` in the serving command below.

### How was this patch tested?

```
export VLLM_ENABLE_MC2=0
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1

source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh

nohup python -m vllm.entrypoints.openai.api_server \
    --model=/path/to/DeepSeek-R1-W8A8 \
    --quantization ascend \
    --served-model-name auto \
    --trust-remote-code \
    --distributed-executor-backend=mp \
    --port 8006 \
    -tp=8 \
    -dp=2 \
    --max-num-seqs 24 \
    --max-model-len 4096 \
    --max-num-batched-tokens 4096 \
    --block-size 128 \
    -O 0 \
    --no-enable-prefix-caching \
    --additional-config '{"torchair_graph_batch_sizes":[24],"expert_tensor_parallel_size":16,"ascend_scheduler_config":{},"enable_graph_mode":true}' \
    --gpu-memory-utilization 0.95 &> run.log &
disown
```
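
For a quick end-to-end check once the server is up, a minimal client request against the OpenAI-compatible endpoint can be used. The port and served model name follow the launch command above; the prompt and timeout are illustrative:

```python
# Minimal sanity check against the server started above. Assumes the
# OpenAI-compatible /v1/completions route on port 8006 and the served
# model name "auto" from the launch command; prompt/timeout are illustrative.
import requests

resp = requests.post(
    "http://localhost:8006/v1/completions",
    json={"model": "auto", "prompt": "The capital of France is", "max_tokens": 16},
    timeout=60,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])
```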

Signed-off-by: boying <897013703@qq.com>
Commit da9acfca60 (parent 517811449e), authored by NeverRaR on 2025-06-04 18:31:41 +08:00, committed via GitHub.
8 changed files with 212 additions and 88 deletions.


```diff
@@ -212,6 +212,14 @@ class CustomDeepseekV2MoE(nn.Module):
         self.tp_group = get_tp_group().device_group
         self.tp_rank = get_tp_group().rank_in_group
 
+        self.params_dtype = torch.get_default_dtype()
+
+        self.enable_graph_mode = False
+        additional_config = get_current_vllm_config().additional_config
+        if additional_config:
+            self.enable_graph_mode = additional_config.get(
+                "enable_graph_mode", False)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
```
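
For context, the `enable_graph_mode` flag read above comes from the `--additional-config` JSON shown in the serving command; a standalone sketch of that lookup (a plain dict, outside vLLM):

```python
# Standalone sketch of the constructor lookup above: --additional-config is
# parsed into a dict; enable_graph_mode defaults to False when absent.
additional_config = {
    "torchair_graph_batch_sizes": [24],
    "expert_tensor_parallel_size": 16,
    "ascend_scheduler_config": {},
    "enable_graph_mode": True,
}

enable_graph_mode = False
if additional_config:
    enable_graph_mode = additional_config.get("enable_graph_mode", False)
assert enable_graph_mode is True
```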
```diff
@@ -228,33 +236,35 @@ class CustomDeepseekV2MoE(nn.Module):
         else:
             is_prefill = attn_metadata.num_prefills > 0
             enable_force_load_balance = False
-        num_tokens, hidden_dim = hidden_states.shape
+            if hasattr(attn_metadata, 'with_prefill_across_dp'):
+                is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
+
+        num_tokens, hidden_size = hidden_states.shape
 
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
 
         if self.tp_size > 1:
-            # pass
-            num_tokens, hidden_size = hidden_states.shape
-            if num_tokens < self.tp_size:
-                target_size = self.tp_size
-                new_hidden_states = torch.empty([target_size, hidden_size],
-                                                dtype=hidden_states.dtype,
-                                                device=hidden_states.device)
-                new_hidden_states[:num_tokens] = hidden_states
-                hidden_states = new_hidden_states
-            chunk_hidden_states = torch.tensor_split(hidden_states,
-                                                     self.tp_size,
-                                                     dim=0)
-            local_hidden_states = chunk_hidden_states[self.tp_rank]
-        else:
-            local_hidden_states = hidden_states
+            if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
+                chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
+                hidden_states = chunks[self.tp_rank]
+            elif not self.enable_graph_mode:
+                num_padding_tokens = (self.tp_size -
+                                      num_tokens % self.tp_size) % self.tp_size
+                # Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
+                if num_padding_tokens > 0:
+                    hidden_states = nn.functional.pad(
+                        hidden_states, (0, 0, 0, num_padding_tokens))
+                chunk_hidden_states = torch.tensor_split(hidden_states,
+                                                         self.tp_size,
+                                                         dim=0)
+                hidden_states = chunk_hidden_states[self.tp_rank]
 
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(local_hidden_states)
+        router_logits, _ = self.gate(hidden_states)
 
-        router_hidden_states = self.experts(
-            hidden_states=local_hidden_states,
+        hidden_states = self.experts(
+            hidden_states=hidden_states,
             router_logits=router_logits,
             is_prefill=is_prefill,
             top_k=CustomDeepseekV2MoE.top_k,
```
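
The eager (non-graph) path above pads the token dimension so it splits evenly across TP ranks; a self-contained sketch of that arithmetic with illustrative sizes (not part of the diff):

```python
# Sketch of the pad/split step above with illustrative sizes
# (tp_size=8, 21 tokens).
import torch
import torch.nn as nn

tp_size, tp_rank = 8, 3
num_tokens, hidden_size = 21, 16
hidden_states = torch.randn(num_tokens, hidden_size)

# (tp_size - num_tokens % tp_size) % tp_size -> 3 here; the outer modulo
# yields 0 when num_tokens is already divisible by tp_size.
num_padding_tokens = (tp_size - num_tokens % tp_size) % tp_size
if num_padding_tokens > 0:
    # Pad the token (second-to-last) dimension with zero rows at the end.
    hidden_states = nn.functional.pad(hidden_states,
                                      (0, 0, 0, num_padding_tokens))
chunk_hidden_states = torch.tensor_split(hidden_states, tp_size, dim=0)
local_hidden_states = chunk_hidden_states[tp_rank]
assert local_hidden_states.shape == (24 // tp_size, hidden_size)  # 3 rows per rank
```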
```diff
@@ -262,18 +272,29 @@ class CustomDeepseekV2MoE(nn.Module):
         ) * self.routed_scaling_factor
 
         if self.tp_size > 1:
-            dist.all_gather(list(chunk_hidden_states), router_hidden_states,
-                            self.tp_group)
-            final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
-            if num_tokens < self.tp_size:
-                final_hidden_states = final_hidden_states[:num_tokens]
-        else:
-            final_hidden_states = router_hidden_states
+            if self.enable_graph_mode:
+                if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
+                    final_hidden_states = torch.zeros(
+                        [num_tokens, hidden_size],
+                        dtype=self.params_dtype,
+                        device="npu")
+                    dist.all_gather_into_tensor(final_hidden_states,
+                                                hidden_states, self.tp_group)
+                    hidden_states = final_hidden_states
+                else:
+                    hidden_states = tensor_model_parallel_all_reduce(
+                        hidden_states)
+            else:
+                dist.all_gather(list(chunk_hidden_states), hidden_states,
+                                self.tp_group)
+                hidden_states = torch.cat(chunk_hidden_states, dim=0)
+                if num_padding_tokens > 0:
+                    hidden_states = hidden_states[:-num_padding_tokens]
 
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            hidden_states = hidden_states + shared_output
 
-        return final_hidden_states.view(num_tokens, hidden_dim)
+        return hidden_states.view(num_tokens, hidden_size)
 
 
 class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
```
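
After the experts run, the eager path gathers the per-rank chunks back and strips the padding added earlier; a single-process sketch of the cat + un-pad step (the collective itself is omitted, and the chunk contents are random stand-ins):

```python
# Single-process sketch of the re-assembly in the eager gather path above:
# concatenate the gathered per-rank chunks, then drop the padding rows.
# No process group is involved; sizes match the padding sketch earlier.
import torch

tp_size, num_padding_tokens, hidden_size = 8, 3, 16
chunk_hidden_states = [torch.randn(3, hidden_size) for _ in range(tp_size)]

hidden_states = torch.cat(chunk_hidden_states, dim=0)    # (24, hidden_size)
if num_padding_tokens > 0:
    hidden_states = hidden_states[:-num_padding_tokens]  # back to (21, hidden_size)
assert hidden_states.shape == (21, hidden_size)
```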