feat: support data parallel for deepseek (#1012)
### What this PR does / why we need it?
Adds data parallel (DP) inference support for DeepSeek models: forward metadata is synchronized across DP ranks, and idle ranks execute padded dummy batches so collectives stay in lockstep.
### Does this PR introduce _any_ user-facing change?
Yes: data parallelism can now be enabled for DeepSeek models (e.g. via `-dp`, as in the command below).
### How was this patch tested?
```
export VLLM_ENABLE_MC2=0
export VLLM_USE_V1=1
export TASK_QUEUE_ENABLE=1
source /usr/local/Ascend/ascend-toolkit/set_env.sh
source /usr/local/Ascend/nnal/atb/set_env.sh
nohup python -m vllm.entrypoints.openai.api_server \
    --model=/path/to/DeepSeek-R1-W8A8 \
    --quantization ascend \
    --served-model-name auto \
    --trust-remote-code \
    --distributed-executor-backend=mp \
    --port 8006 \
    -tp=8 \
    -dp=2 \
    --max-num-seqs 24 \
    --max-model-len 4096 \
    --max-num-batched-tokens 4096 \
    --block-size 128 \
    -O 0 \
    --no-enable-prefix-caching \
    --additional-config '{"torchair_graph_batch_sizes":[24],"expert_tensor_parallel_size":16,"ascend_scheduler_config":{},"enable_graph_mode":true}' \
    --gpu-memory-utilization 0.95 &> run.log &
disown
```
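Once the server is up, a quick sanity check against the OpenAI-compatible endpoint (illustrative only, not part of the patch; port `8006` and served model name `auto` come from the launch command above):

```python
# Assumes `pip install openai`; any string works as the API key for a local vLLM server.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8006/v1", api_key="EMPTY")
completion = client.completions.create(model="auto",
                                       prompt="The capital of France is",
                                       max_tokens=16)
print(completion.choices[0].text)
```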
Signed-off-by: boying <897013703@qq.com>
```diff
@@ -29,12 +29,14 @@ import numpy as np
 import numpy.typing as npt
 import torch
 import torch._dynamo.cache_size
+import torch.distributed as dist
 import torch.nn as nn
+from torch.distributed import ReduceOp
 from vllm.attention import AttentionType, get_attn_backend
 from vllm.attention.layer import Attention
 from vllm.config import CompilationLevel, VllmConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.distributed.parallel_state import get_pp_group
+from vllm.distributed.parallel_state import get_dp_group, get_pp_group
 from vllm.forward_context import set_forward_context
 from vllm.inputs import INPUT_REGISTRY
 from vllm.logger import logger
@@ -361,6 +363,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         torch._logging.set_logs(
             recompiles=envs_ascend.VLLM_ASCEND_TRACE_RECOMPILES)
 
+        self.dp_size = vllm_config.parallel_config.data_parallel_size
+        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
+
     def _update_states(self, scheduler_output: "SchedulerOutput") -> None:
         """Update the cached states and the persistent batch with the scheduler
         output.
@@ -512,6 +517,16 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         if batch_changed:
             self.input_batch.refresh_sampling_metadata()
 
+    def _get_forward_metadata_across_dp(
+            self, batch_size: int, with_prefill: bool) -> tuple[int, bool]:
+        forward_metadata = torch.tensor([batch_size, with_prefill],
+                                        device="cpu",
+                                        dtype=torch.int32)
+        dist.all_reduce(forward_metadata,
+                        op=ReduceOp.MAX,
+                        group=get_dp_group().cpu_group)
+        return int(forward_metadata[0]), bool(forward_metadata[1] > 0)
+
     def get_model(self) -> nn.Module:
         return self.model
 
```
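In `_get_forward_metadata_across_dp`, packing `batch_size` and `with_prefill` into one int32 tensor lets a single MAX all-reduce over the CPU group return both the largest batch on any DP rank (for uniform graph padding) and a logical OR of the prefill flags. A minimal standalone sketch of the same pattern, using a single-process `gloo` group and hypothetical names:

```python
import os

import torch
import torch.distributed as dist


def sync_forward_metadata(batch_size: int, with_prefill: bool) -> tuple[int, bool]:
    # Pack both values into one tensor so a single collective suffices.
    meta = torch.tensor([batch_size, int(with_prefill)], dtype=torch.int32)
    # MAX over element 0 -> largest batch size on any DP rank;
    # MAX over element 1 -> 1 if any rank still has a prefill (acts as OR).
    dist.all_reduce(meta, op=dist.ReduceOp.MAX)
    return int(meta[0]), bool(meta[1] > 0)


if __name__ == "__main__":
    # Single-process group purely for demonstration; the real DP ranks
    # already share a gloo-backed CPU process group.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)
    print(sync_forward_metadata(24, False))  # -> (24, False) with one rank
    dist.destroy_process_group()
```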
```diff
@@ -648,12 +663,24 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         seq_lens = self.seq_lens[:num_reqs]
         common_attn_metadata = CommonAttentionMetadata(
             query_start_loc=query_start_loc, seq_lens=seq_lens)
+        with_prefill = attn_state != AscendAttentionState.DecodeOnly
+
+        if self.dp_size > 1:
+            max_num_tokens, with_prefill = self._get_forward_metadata_across_dp(
+                total_num_scheduled_tokens, with_prefill)
+            extra_builder_kwargs['with_prefill_across_dp'] = with_prefill
+
         # Add graph_pad_size here
-        if self.enable_torchair_graph_mode:
-            batchsize = len(seq_lens)
-            padded_batch_size = self.select_torchair_padded_batchsize(
-                batchsize)
-            graph_pad_size = padded_batch_size - batchsize
+        if envs_ascend.VLLM_ENABLE_MC2 or (self.enable_torchair_graph_mode
+                                           and not with_prefill):
+            batch_size = len(seq_lens)
+            if self.dp_size > 1:
+                padded_batch_size = self.select_torchair_padded_batch_size(
+                    max_num_tokens)
+            else:
+                padded_batch_size = self.select_torchair_padded_batch_size(
+                    batch_size)
+            graph_pad_size = padded_batch_size - batch_size
             extra_builder_kwargs['graph_pad_size'] = graph_pad_size
 
         if self.vllm_config.model_config.use_mla:
@@ -687,7 +714,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.input_ids_cpu[:total_num_scheduled_tokens], non_blocking=True)
         input_ids = self.input_ids[:num_input_tokens]
 
-        if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+        if (envs_ascend.VLLM_ENABLE_MC2
+                or self.enable_torchair_graph_mode) and not with_prefill:
             input_ids = self.input_ids[:padded_batch_size]
             positions = self.positions[:padded_batch_size]
 
@@ -699,7 +727,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             if self.enable_torchair_graph_mode:
                 model_kwargs["kv_caches"] = self.kv_caches
                 model_kwargs["attn_metadata"] = attn_metadata
-            if self.enable_torchair_graph_mode and attn_metadata.attn_state == AscendAttentionState.DecodeOnly:
+            if self.enable_torchair_graph_mode and not with_prefill:
                 hidden_states = self.compile_model(
                     input_ids=input_ids,
                     positions=positions,
@@ -1095,7 +1123,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         self,
         num_tokens: int,
         is_compile: bool = False,
-        attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill,
+        with_prefill: bool = True,
     ) -> torch.Tensor:
         # Set num_scheduled_tokens based on num_tokens and max_num_seqs
         # for dummy run with LoRA so that the num_reqs collectively
@@ -1139,8 +1167,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 for k, v in self.intermediate_tensors.items()
             })
 
-        with set_forward_context(None, self.vllm_config):
-            if self.enable_torchair_graph_mode and attn_state == AscendAttentionState.DecodeOnly:
+        with set_forward_context(None,
+                                 self.vllm_config,
+                                 num_tokens=num_tokens):
+            if self.enable_torchair_graph_mode and not with_prefill:
                 attn_metadata = self.attn_metadata_builder.build_dummy(
                     num_reqs=num_tokens, num_actual_tokens=1)
                 # Only mark static while compiling
@@ -1393,7 +1423,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             logger.info(
                 "Capturing torchair graph, this usually takes %.1f~%.1f mins.",
                 0.5 * graph_num, 1.5 * graph_num)
-            attn_state = AscendAttentionState.DecodeOnly
             # Trigger torchair graph capture for specific shapes.
             # Capture the large shapes first so that the smaller shapes
             # can reuse the memory pool allocated for the large shapes.
@@ -1403,10 +1432,10 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                         cudagraph_num_of_warmups):
                     self._dummy_run(num_tokens,
                                     is_compile=True,
-                                    attn_state=attn_state)
+                                    with_prefill=False)
                 self._dummy_run(num_tokens,
                                 is_compile=True,
-                                attn_state=attn_state)
+                                with_prefill=False)
                 logger.info("Batchsize %d is compiled successfully: %d/%d.",
                             num_tokens, idx + 1, graph_num)
         elif self.use_aclgraph:
@@ -1551,9 +1580,9 @@ class NPUModelRunner(LoRAModelRunnerMixin):
             self.torchair_graph_batch_sizes.append(largest_batch_size)
             largest_batch_size += batch_size_step
 
-    def select_torchair_padded_batchsize(self, batchsize: int):
-        selected_batchsize = self.max_num_reqs
-        for padded_batchsize in self.torchair_graph_batch_sizes:
-            if batchsize <= padded_batchsize < selected_batchsize:
-                selected_batchsize = padded_batchsize
-        return selected_batchsize
+    def select_torchair_padded_batch_size(self, batch_size: int):
+        selected_batch_size = self.max_num_reqs
+        for padded_batch_size in self.torchair_graph_batch_sizes:
+            if batch_size <= padded_batch_size < selected_batch_size:
+                selected_batch_size = padded_batch_size
+        return selected_batch_size
```
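The renamed `select_torchair_padded_batch_size` picks the smallest captured torchair graph batch size that still fits the live batch, falling back to `max_num_reqs` when none does. A worked illustration with assumed capture sizes, as a standalone function mirroring the method above:

```python
# Mirrors select_torchair_padded_batch_size: choose the smallest captured
# torchair graph batch size that can hold the live batch; if none fits,
# fall back to max_num_reqs. The sizes below are assumed for illustration.
def select_padded_batch_size(batch_size: int, graph_batch_sizes: list[int],
                             max_num_reqs: int) -> int:
    selected = max_num_reqs
    for padded in graph_batch_sizes:
        if batch_size <= padded < selected:
            selected = padded
    return selected


assert select_padded_batch_size(10, [8, 16, 24], max_num_reqs=24) == 16
assert select_padded_batch_size(16, [8, 16, 24], max_num_reqs=24) == 16
assert select_padded_batch_size(25, [8, 16, 24], max_num_reqs=32) == 32
```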
```diff
@@ -544,7 +544,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
         init_ascend_model_parallel(
             parallel_config.expert_parallel_size,
             parallel_config.expert_tensor_parallel_size,
-            parallel_config.world_size,
+            parallel_config.world_size_across_dp,
         )
         ensure_kv_transfer_initialized(vllm_config)
 
@@ -41,6 +41,7 @@ from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import bind_kv_cache
 from vllm.v1.worker.worker_base import WorkerBase
 
+import vllm_ascend.envs as envs_ascend
 from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel
 from vllm_ascend.platform import NPUPlatform
 from vllm_ascend.utils import try_register_lib
@@ -230,7 +231,18 @@ class NPUWorker(WorkerBase):
         return self.model_runner.pin_lora(lora_id)
 
     def execute_dummy_batch(self) -> None:
-        self.model_runner._dummy_run(1)
+        runner = self.model_runner
+        num_tokens = 1
+        if runner.dp_size > 1:
+            max_num_tokens, with_prefill = runner._get_forward_metadata_across_dp(
+                1, False)
+            if envs_ascend.VLLM_ENABLE_MC2 or runner.enable_torchair_graph_mode:
+                if not with_prefill:
+                    num_tokens = max_num_tokens
+                    num_tokens = runner.select_torchair_padded_batch_size(num_tokens)
+        runner._dummy_run(num_tokens,
+                          is_compile=False,
+                          with_prefill=with_prefill)
 
     def _init_worker_distributed_environment(self) -> None:
         """Initialize the distributed environment."""
@@ -246,7 +258,7 @@ class NPUWorker(WorkerBase):
         init_ascend_model_parallel(
             parallel_config.expert_parallel_size,
             parallel_config.expert_tensor_parallel_size,
-            parallel_config.world_size,
+            parallel_config.world_size_across_dp,
         )
         ensure_kv_transfer_initialized(self.vllm_config)
```