feat: support data parallel for deepseek (#1012)
### What this PR does / why we need it?
feat: support data parallel for deepseek
### Does this PR introduce _any_ user-facing change?
Yes. This PR adds data parallel (DP) support for DeepSeek models.
### How was this patch tested?
```
export VLLM_ENABLE_MC2=0
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
nohup python -m vllm.entrypoints.openai.api_server \
--model=/path/to/DeepSeek-R1-W8A8 \
--quantization ascend \
--served-model-name auto \
--trust-remote-code \
--distributed-executor-backend=mp \
--port 8006 \
-tp=8 \
-dp=2 \
--max-num-seqs 24 \
--max-model-len 4096 \
--max-num-batched-tokens 4096 \
--block-size 128 \
-O 0 \
--no-enable-prefix-caching \
--additional-config '{"torchair_graph_batch_sizes":[24],"expert_tensor_parallel_size":16,"ascend_scheduler_config":{},"enable_graph_mode":true}' \
--gpu-memory-utilization 0.95 &> run.log &
disown
```
Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
@@ -212,6 +212,14 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
self.tp_group = get_tp_group().device_group
|
||||
self.tp_rank = get_tp_group().rank_in_group
|
||||
|
||||
self.params_dtype = torch.get_default_dtype()
|
||||
|
||||
self.enable_graph_mode = False
|
||||
additional_config = get_current_vllm_config().additional_config
|
||||
if additional_config:
|
||||
self.enable_graph_mode = additional_config.get(
|
||||
"enable_graph_mode", False)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
@@ -228,33 +236,35 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
else:
|
||||
is_prefill = attn_metadata.num_prefills > 0
|
||||
enable_force_load_balance = False
|
||||
num_tokens, hidden_dim = hidden_states.shape
|
||||
if hasattr(attn_metadata, 'with_prefill_across_dp'):
|
||||
is_prefill = is_prefill or attn_metadata.with_prefill_across_dp
|
||||
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
|
||||
if self.n_shared_experts is not None:
|
||||
shared_output = self.shared_experts(hidden_states)
|
||||
|
||||
if self.tp_size > 1:
|
||||
# pass
|
||||
num_tokens, hidden_size = hidden_states.shape
|
||||
if num_tokens < self.tp_size:
|
||||
target_size = self.tp_size
|
||||
new_hidden_states = torch.empty([target_size, hidden_size],
|
||||
dtype=hidden_states.dtype,
|
||||
device=hidden_states.device)
|
||||
new_hidden_states[:num_tokens] = hidden_states
|
||||
hidden_states = new_hidden_states
|
||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
local_hidden_states = chunk_hidden_states[self.tp_rank]
|
||||
else:
|
||||
local_hidden_states = hidden_states
|
||||
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
|
||||
chunks = torch.chunk(hidden_states, self.tp_size, dim=0)
|
||||
hidden_states = chunks[self.tp_rank]
|
||||
elif not self.enable_graph_mode:
|
||||
num_padding_tokens = (self.tp_size -
|
||||
num_tokens % self.tp_size) % self.tp_size
|
||||
# Pad hidden_states to make it divisible by tp_size to avoid cross-ring AllGatherV on 910B2C
|
||||
if num_padding_tokens > 0:
|
||||
hidden_states = nn.functional.pad(
|
||||
hidden_states, (0, 0, 0, num_padding_tokens))
|
||||
chunk_hidden_states = torch.tensor_split(hidden_states,
|
||||
self.tp_size,
|
||||
dim=0)
|
||||
hidden_states = chunk_hidden_states[self.tp_rank]
|
||||
|
||||
# router_logits: (num_tokens, n_experts)
|
||||
router_logits, _ = self.gate(local_hidden_states)
|
||||
router_logits, _ = self.gate(hidden_states)
|
||||
|
||||
router_hidden_states = self.experts(
|
||||
hidden_states=local_hidden_states,
|
||||
hidden_states = self.experts(
|
||||
hidden_states=hidden_states,
|
||||
router_logits=router_logits,
|
||||
is_prefill=is_prefill,
|
||||
top_k=CustomDeepseekV2MoE.top_k,
|
||||
@@ -262,18 +272,29 @@ class CustomDeepseekV2MoE(nn.Module):
|
||||
) * self.routed_scaling_factor
|
||||
|
||||
if self.tp_size > 1:
|
||||
dist.all_gather(list(chunk_hidden_states), router_hidden_states,
|
||||
self.tp_group)
|
||||
final_hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
if num_tokens < self.tp_size:
|
||||
final_hidden_states = final_hidden_states[:num_tokens]
|
||||
else:
|
||||
final_hidden_states = router_hidden_states
|
||||
if self.enable_graph_mode:
|
||||
if envs_ascend.VLLM_ENABLE_MC2 and not is_prefill:
|
||||
final_hidden_states = torch.zeros(
|
||||
[num_tokens, hidden_size],
|
||||
dtype=self.params_dtype,
|
||||
device="npu")
|
||||
dist.all_gather_into_tensor(final_hidden_states,
|
||||
hidden_states, self.tp_group)
|
||||
hidden_states = final_hidden_states
|
||||
else:
|
||||
hidden_states = tensor_model_parallel_all_reduce(
|
||||
hidden_states)
|
||||
else:
|
||||
dist.all_gather(list(chunk_hidden_states), hidden_states,
|
||||
self.tp_group)
|
||||
hidden_states = torch.cat(chunk_hidden_states, dim=0)
|
||||
if num_padding_tokens > 0:
|
||||
hidden_states = hidden_states[:-num_padding_tokens]
|
||||
|
||||
if shared_output is not None:
|
||||
final_hidden_states = final_hidden_states + shared_output
|
||||
hidden_states = hidden_states + shared_output
|
||||
|
||||
return final_hidden_states.view(num_tokens, hidden_dim)
|
||||
return hidden_states.view(num_tokens, hidden_size)
|
||||
|
||||
|
||||
class CustomDeepseekV2MLAAttention(DeepseekV2MLAAttention):
|
||||
|
||||
Reference in New Issue
Block a user