feat: support data parallel for deepseek (#1012)

### What this PR does / why we need it?
feat: support data parallel for deepseek

### Does this PR introduce _any_ user-facing change?
Yes, support dp for deepseek

### How was this patch tested?

```
export VLLM_ENABLE_MC2=0
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1

source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh

nohup python -m vllm.entrypoints.openai.api_server
--model=/path/to/DeepSeek-R1-W8A8 \
    --quantization ascend \
    --served-model-name auto \
    --trust-remote-code \
    --distributed-executor-backend=mp \
    --port 8006 \
    -tp=8 \
    -dp=2 \
    --max-num-seqs 24 \
    --max-model-len 4096 \
    --max-num-batched-tokens 4096 \
    --block-size 128 \
    -O 0 \
    --no-enable-prefix-caching \
--additional-config
'{"torchair_graph_batch_sizes":[24],"expert_tensor_parallel_size":16,"ascend_scheduler_config":{},"enable_graph_mode":true}'
\
    --gpu-memory-utilization 0.95 &> run.log &
disown
```

Signed-off-by: boying <897013703@qq.com>
This commit is contained in:
NeverRaR
2025-06-04 18:31:41 +08:00
committed by GitHub
parent 517811449e
commit da9acfca60
8 changed files with 212 additions and 88 deletions

View File

@@ -117,6 +117,8 @@ class AscendMLAMetadata:
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
with_prefill_across_dp: bool = False
# The dimension of the attention heads
head_dim: Optional[int] = None
attn_mask: torch.Tensor = None
@@ -260,6 +262,10 @@ class AscendMLAMetadataBuilder:
PAD_SLOT_ID,
dtype=torch.int32,
device=device)
query_start_loc = torch.full((num_reqs, ),
-1,
dtype=torch.int32,
device=device)
decode_metadata = AscendMLADecodeMetadata(
input_positions=input_positions,
block_table=block_table,
@@ -278,15 +284,21 @@ class AscendMLAMetadataBuilder:
attn_state=AscendAttentionState.DecodeOnly,
prefill=None,
decode=decode_metadata,
query_start_loc=query_start_loc,
seq_lens=seq_lens,
block_tables=block_table,
)
def build(self,
num_reqs: int,
num_actual_tokens: int,
max_query_len: int,
common_attn_metadata: CommonAttentionMetadata,
common_prefix_len: Optional[int] = None,
graph_pad_size: int = -1) -> AscendMLAMetadata:
def build(
self,
num_reqs: int,
num_actual_tokens: int,
max_query_len: int,
common_attn_metadata: CommonAttentionMetadata,
common_prefix_len: Optional[int] = None,
graph_pad_size: int = -1,
with_prefill_across_dp: bool = False,
) -> AscendMLAMetadata:
assert self._num_decodes + self._num_prefills == num_reqs
# Note(simon): be careful about the CPU <> GPU memory movement in this
@@ -388,6 +400,7 @@ class AscendMLAMetadataBuilder:
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
with_prefill_across_dp=with_prefill_across_dp,
)
@@ -621,7 +634,7 @@ class AscendMLAImpl(MLAAttentionImpl):
kv = self.kv_a_proj_with_mqa(hidden_states)[0]
# npu_kv_rmsnorm_rope_cache needs [B, N, S, D]
kv = kv.view(B, N, S, self.kv_lora_rank + self.qk_rope_head_dim)
k_pe, k_nope, _, _ = torch.ops.npu_inference.npu_kv_rmsnorm_rope_cache(
k_pe, k_nope, _, _ = torch_npu.npu_kv_rmsnorm_rope_cache(
kv,
self.kv_a_layernorm.weight,
cos,
@@ -643,7 +656,7 @@ class AscendMLAImpl(MLAAttentionImpl):
B, N, D = x.shape
S = 1
x = x.view(B, N, S, D)
x = torch.ops.npu_inference.npu_interleave_rope(x, cos, sin)
x = torch_npu.npu_interleave_rope(x, cos, sin)
return x.view(B, N, D)
def _forward_decode(
@@ -766,6 +779,7 @@ class AscendMLAImpl(MLAAttentionImpl):
sin = sin[attn_metadata.decode.input_positions]
cos = cos[:, None, None, :]
sin = sin[:, None, None, :]
decode_q_pe = self.rope_single(decode_q_pe, cos, sin)
decode_k_pe, decode_k_nope = self.exec_kv(
hidden_states_or_kv_c_normed, cos, sin, kv_cache,