[Feature] Add PD separation feature (#432)

### What this PR does / why we need it?
Adapt Disaggregated Prefill feature onto Ascend device

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

The test usage has been provided alongwith the PR, in
examples/offline_disaggregated_prefill_npu.py
To run it, do this
```
export PROMPT_DEVICE_ID=0,1
export DECODE_DEVICE_ID=2,3
python examples/offline_disaggregated_prefill_npu.py
```

---------

Signed-off-by: ZihuiQian <qianzihui@huawei.com>
Co-authored-by: ZihuiQian <qianzihui@huawei.com>
This commit is contained in:
eeethenQ
2025-04-15 15:11:35 +08:00
committed by GitHub
parent c7f6584d75
commit 44a8301424
8 changed files with 634 additions and 8 deletions

View File

@@ -24,8 +24,9 @@ import torch
import torch.distributed
from torch import nn
from vllm import envs
from vllm.config import ParallelConfig, VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized,
from vllm.config import VllmConfig
from vllm.distributed import (ensure_kv_transfer_initialized,
ensure_model_parallel_initialized,
init_distributed_environment,
set_custom_all_reduce)
from vllm.logger import logger
@@ -161,8 +162,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment.
self._init_worker_distributed_environment(self.parallel_config,
self.rank,
self._init_worker_distributed_environment(self.vllm_config, self.rank,
self.distributed_init_method,
self.local_rank)
# Set random seed.
@@ -450,12 +450,13 @@ class NPUWorker(LocalOrDistributedWorkerBase):
def _init_worker_distributed_environment(
self,
parallel_config: ParallelConfig,
vllm_config: VllmConfig,
rank: int,
distributed_init_method: Optional[str] = None,
local_rank: int = -1,
backend: str = "hccl") -> None:
"""Initialize the distributed environment."""
parallel_config = self.parallel_config
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank,
@@ -463,6 +464,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size)
ensure_kv_transfer_initialized(vllm_config)
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,