[Feature] Add PD separation feature (#432)

### What this PR does / why we need it?
Adapt the Disaggregated Prefill feature to Ascend devices.

### Does this PR introduce _any_ user-facing change?

no

### How was this patch tested?

A usage example is provided along with this PR in
examples/offline_disaggregated_prefill_npu.py.
To run it:
```
export PROMPT_DEVICE_ID=0,1
export DECODE_DEVICE_ID=2,3
python examples/offline_disaggregated_prefill_npu.py
```

---------

Signed-off-by: ZihuiQian <qianzihui@huawei.com>
Co-authored-by: ZihuiQian <qianzihui@huawei.com>
This commit is contained in:
eeethenQ
2025-04-15 15:11:35 +08:00
committed by GitHub
parent c7f6584d75
commit 44a8301424
8 changed files with 634 additions and 8 deletions

View File

@@ -0,0 +1,140 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/examples/offline_inference/basic.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import multiprocessing as mp
import os
import time
from multiprocessing import Event, Process
def clean_up():
    """Tear down vLLM's parallel state and release cached NPU memory."""
    # Imports stay local so nothing heavy is loaded at module import time.
    import gc

    import torch
    from vllm.distributed import parallel_state

    # Model-parallel groups first, then the distributed environment itself.
    parallel_state.destroy_model_parallel()
    parallel_state.destroy_distributed_environment()

    # Run a garbage-collection pass, then empty the NPU caching allocator.
    gc.collect()
    torch.npu.empty_cache()
def run_prefill(prefill_done, process_close):
    """Run the prefill (KV producer) half of the disaggregated example.

    Args:
        prefill_done: Event that is set once the prompts have been run
            through the prefill engine.
        process_close: Event polled once per second after prefill; when it
            is set the engine is released and the function returns.
    """
    # Must be in the environment before vllm is imported below.
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    producer_config = KVTransferConfig.from_cli(
        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}'
    )
    # A single generated token is requested (max_tokens=1).
    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
    prompts = [
        "Hello, how are you today?",
        "Hi, what is your name?",
        "Tell me a very long story.",
        "what is your favourite book?",
    ]

    # NPU memory utilization is capped at 0.8.
    llm = LLM(
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        tensor_parallel_size=2,
        gpu_memory_utilization=0.8,
        max_model_len=2000,
        kv_transfer_config=producer_config,
    )

    llm.generate(prompts, sampling_params)
    print("Prefill node is finished.")
    prefill_done.set()

    # Keep the prefill node alive until the decode side signals that it is
    # done; exiting earlier could leave decoding incomplete.
    try:
        while True:
            if process_close.is_set():
                break
            time.sleep(1)
    except KeyboardInterrupt:
        print("Script stopped by user.")
    finally:
        print("Cleanup prefill resources")
        del llm
        clean_up()
def run_decode(prefill_done):
    """Run the decode (KV consumer) half of the disaggregated example.

    Args:
        prefill_done: Event set by the prefill process; decoding starts
            only after it has been set.
    """
    # Must be in the environment before vllm is imported below.
    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "2,3"

    from vllm import LLM, SamplingParams
    from vllm.config import KVTransferConfig

    consumer_config = KVTransferConfig.from_cli(
        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
    )
    sampling_params = SamplingParams(temperature=0, top_p=0.95)
    prompts = [
        "Hello, how are you today?",
        "Hi, what is your name?",
        "Tell me a very long story.",
        "what is your favourite book?",
    ]

    llm = LLM(
        model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
        tensor_parallel_size=2,
        gpu_memory_utilization=0.8,
        max_model_len=2000,
        kv_transfer_config=consumer_config,
    )

    # Block until the producer reports that prefill has finished.
    print("Waiting for prefill node to finish...")
    prefill_done.wait()

    # Once prefill_done is set the KV cache should have been made available
    # to this decode node, so generation can start.
    for result in llm.generate(prompts, sampling_params):
        print(f"Prompt: {result.prompt!r}, "
              f"Generated text: {result.outputs[0].text!r}")

    del llm
    clean_up()
if __name__ == "__main__":
    # mp.get_context() only *returns* a context object; it does not change
    # the default start method. The events and processes therefore have to
    # be created from the returned context for "spawn" to take effect.
    ctx = mp.get_context("spawn")

    prefill_done = ctx.Event()
    process_close = ctx.Event()

    prefill_process = ctx.Process(target=run_prefill,
                                  args=(
                                      prefill_done,
                                      process_close,
                                  ))
    decode_process = ctx.Process(target=run_decode, args=(prefill_done, ))

    # Start prefill node
    prefill_process.start()

    # Start decode node
    decode_process.start()

    # Wait for decoding to finish before shutting the prefill node down.
    decode_process.join()

    # Ask the prefill process to clean up and wait for it to exit. join()
    # without a timeout only returns once the process has ended, so no
    # explicit terminate() is needed afterwards.
    process_close.set()
    prefill_process.join()

    print("All process done!")

View File

@@ -0,0 +1,6 @@
from vllm.distributed.kv_transfer.kv_connector.factory import \
KVConnectorFactory
# Register the Ascend connector with vLLM's KV connector factory under the
# name "AscendHcclConnector". The implementation is referenced by module
# path and class name given as strings.
KVConnectorFactory.register_connector(
    "AscendHcclConnector", "vllm_ascend.distributed.llmdatadist_connector",
    "LLMDataDistConnector")

View File

@@ -0,0 +1,465 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import os
import re
import subprocess
from typing import TYPE_CHECKING, List, Tuple, Union
import torch
import torch_npu
import torchair # type: ignore
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.logger import logger
from vllm.sequence import IntermediateTensors
import vllm_ascend.envs as envs
if TYPE_CHECKING:
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
import llm_datadist # type: ignore
# Maps torch dtypes to the llm_datadist.DataType value passed to
# llm_datadist.CacheDesc when caches are described in this module.
# torch.half/torch.float16 and torch.float/torch.float32 are each listed and
# map to the same DataType.
TORCH_DTYPE_TO_NPU_DTYPE = {
    torch.half: llm_datadist.DataType.DT_FLOAT16,
    torch.float16: llm_datadist.DataType.DT_FLOAT16,
    torch.bfloat16: llm_datadist.DataType.DT_BF16,
    torch.float: llm_datadist.DataType.DT_FLOAT,
    torch.float32: llm_datadist.DataType.DT_FLOAT,
    torch.int8: llm_datadist.DataType.DT_INT8,
    torch.int64: llm_datadist.DataType.DT_INT64,
    torch.int32: llm_datadist.DataType.DT_INT32
}
# Path to the hccn_tool executable (from vllm_ascend.envs); get_device_ips()
# runs it to query the IP address of each device.
HCCN_TOOL_PATH = envs.HCCN_PATH
def get_device_ips(world_size: int = 8):
    """Return the IP address of each NPU on this host, in device order.

    The index of the first NPU is parsed from the output of
    ``npu-smi info -m``; ``hccn_tool -i <idx> -ip -g`` is then run for
    ``world_size`` consecutive indices starting from it.

    Args:
        world_size: Number of consecutive devices to query. Defaults to 8.

    Returns:
        A list of ``world_size`` IP address strings.

    Raises:
        RuntimeError: If ``npu-smi`` or ``hccn_tool`` is unavailable, or if
            the output of either tool cannot be parsed.
    """
    try:
        npu_info = subprocess.run(['npu-smi', 'info', '-m'],
                                  stdout=subprocess.PIPE,
                                  stderr=subprocess.PIPE,
                                  universal_newlines=True)
    except FileNotFoundError as err:
        # A missing binary makes subprocess.run raise instead of returning
        # a non-zero return code; report it the same way as below.
        raise RuntimeError(
            "No npu-smi/hccn_tool tools provided for NPU.") from err
    if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH):
        raise RuntimeError("No npu-smi/hccn_tool tools provided for NPU.")

    start_match = re.match(r'.*\n\t([0-9]+).*', npu_info.stdout)
    if start_match is None:
        raise RuntimeError(
            "Failed to parse the first NPU index from the output of "
            "`npu-smi info -m`.")
    npu_start_idx = int(start_match.group(1))

    device_ip_list = []
    for ip_offset in range(world_size):
        device_idx = npu_start_idx + ip_offset
        cmd = [HCCN_TOOL_PATH, '-i', f'{device_idx}', '-ip', '-g']
        device_ip_info = subprocess.run(cmd,
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.PIPE,
                                        universal_newlines=True)
        ip_match = re.match(r'ipaddr:(.*)\n', device_ip_info.stdout)
        if ip_match is None:
            raise RuntimeError(
                f"Failed to get the IP address of device {device_idx} from "
                f"hccn_tool: {device_ip_info.stderr.strip()}")
        device_ip_list.append(ip_match.group(1))
    return device_ip_list
class KVTransferEngine:
    """Per-rank wrapper around ``llm_datadist.LLMDataDist``.

    Resolves the device IPs of the prompt and decode devices from the
    ``PROMPT_DEVICE_ID`` / ``DECODE_DEVICE_ID`` environment settings and
    builds the options and cluster description used to connect a decode
    rank to its prompt peer.
    """

    def __init__(self, world_size, n_layer, role, local_rank):
        """Create the LLMDataDist handle and resolve prompt/decode IPs.

        Args:
            world_size: World size of the parallel configuration.
            n_layer: Number of model layers.
            role: ``llm_datadist.LLMRole`` of this process.
            local_rank: Local rank; also used as the cluster id.

        Raises:
            ValueError: If PROMPT_DEVICE_ID or DECODE_DEVICE_ID is unset.
        """
        self.world_size = world_size
        self.n_layer = n_layer
        self.role = role
        self.device_ip_list = get_device_ips()
        self.local_rank = local_rank
        self.cluster_id = local_rank
        self.data_dist = llm_datadist.LLMDataDist(self.role, self.cluster_id)

        prompt_device_ids = envs.PROMPT_DEVICE_ID
        decode_device_ids = envs.DECODE_DEVICE_ID
        # Both lists are needed: the prompt side listens on its own IP and
        # the decode side needs its local IP plus the prompt peer's IP.
        if prompt_device_ids is None or decode_device_ids is None:
            raise ValueError(
                "Please specify both env PROMPT_DEVICE_ID and "
                "DECODE_DEVICE_ID")
        self.prompt_ip_list = self._select_device_ips(prompt_device_ids)
        self.decode_ip_list = self._select_device_ips(decode_device_ids)

    def _select_device_ips(self, device_ids):
        """Map a comma-separated device id string to their IP addresses."""
        return [
            self.device_ip_list[int(part.strip())]
            for part in device_ids.split(",") if part.strip()
        ]

    def prepare_data_dist(self):
        """Initialise LLMDataDist and expose its KV cache manager."""
        options = {
            "llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME,
            "ge.exec.deviceId": str(self.local_rank),
        }
        if self.role == llm_datadist.LLMRole.PROMPT:
            # Only the prompt side listens for incoming links.
            listen_ip = self.prompt_ip_list[self.local_rank]
            options["llm.listenIpInfo"] = (
                f"{listen_ip}:{envs.LLMDATADIST_COMM_PORT}")
        self.data_dist.init(options)
        self.kv_transfer = self.data_dist.kv_cache_manager
        logger.info("%s/%s rank data dist is ready", self.local_rank,
                    self.world_size)

    def make_cluster(self, prefill_ip, cluster_id=-1):
        """Describe the link from this decode rank to a prompt peer.

        Args:
            prefill_ip: IP address of the remote prompt device.
            cluster_id: Id of the remote cluster. Defaults to -1.

        Returns:
            An ``llm_datadist.LLMClusterInfo`` holding the local and remote
            endpoint information.
        """
        cluster = llm_datadist.LLMClusterInfo()
        cluster.remote_cluster_id = cluster_id
        local_ip = self.decode_ip_list[self.local_rank]
        cluster.append_local_ip_info(local_ip, 0)
        # Use the same port the prompt side listens on (see
        # prepare_data_dist) instead of a hard-coded 26000, so that
        # overriding LLMDATADIST_COMM_PORT keeps both sides in agreement.
        cluster.append_remote_ip_info(prefill_ip,
                                      int(envs.LLMDATADIST_COMM_PORT))
        return cluster
class LLMDataDistConnector(KVConnectorBase):
    """KV connector that hands prefill results to decode via llm_datadist.

    A ``kv_producer`` instance copies the KV cache, input tokens and hidden
    states into llm_datadist caches. A ``kv_consumer`` instance links to the
    prompt peer with the same local rank and pulls those caches.
    """

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):
        """Set up the transfer engine and, on the consumer, link clusters.

        Raises:
            NotImplementedError: If ``kv_role`` is neither ``kv_producer``
                nor ``kv_consumer``.
        """
        self.config = config
        self.tp_size = config.parallel_config.tensor_parallel_size
        self.rank = rank
        self.local_rank = local_rank

        kv_role = self.config.kv_transfer_config.kv_role
        if kv_role == "kv_producer":
            self.role = llm_datadist.LLMRole.PROMPT
        elif kv_role == "kv_consumer":
            self.role = llm_datadist.LLMRole.DECODER
        else:
            raise NotImplementedError(
                "kv_role should be inside [kv_producer, kv_consumer]")

        self.world_size = self.config.parallel_config.world_size
        self.n_layer = self.config.model_config.get_num_layers(
            self.config.parallel_config)

        self.llm_datadist_engine = KVTransferEngine(self.world_size,
                                                    self.n_layer, self.role,
                                                    self.local_rank)
        # Only the consumer links to a remote cluster. Keeping the attribute
        # defined (as None) on the producer lets close() run on either role.
        self.cluster = None
        self.llm_datadist_engine.prepare_data_dist()
        if self.role != llm_datadist.LLMRole.PROMPT:
            self.cluster = self.llm_datadist_engine.make_cluster(
                self.llm_datadist_engine.prompt_ip_list[self.local_rank],
                self.llm_datadist_engine.cluster_id)
            _, ret = self.llm_datadist_engine.data_dist.link_clusters(
                [self.cluster], 20000)
            logger.info("local_rank %s link, ret=%s", self.local_rank, ret)

    def send_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors]
    ) -> None:
        """Copy KV cache, input tokens and hidden states into datadist caches.

        Four caches (key, value, input tokens, hidden states) are allocated
        and registered under ``CacheKey(cluster_id, 0, k)`` for k = 1..4,
        then filled request by request using the slot mapping.

        NOTE(review): caches are allocated on every call and are not
        released anywhere in this class - confirm the intended lifetime.
        """
        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer

        model_config = model_executable.model.config
        num_heads = model_config.num_key_value_heads // self.tp_size
        hidden_size = model_config.hidden_size
        num_attention_heads = model_config.num_attention_heads
        head_size = hidden_size // num_attention_heads
        num_layer = end_layer - start_layer

        # Shapes of the datadist caches, sized by the number of input tokens.
        num_input_tokens = input_tokens_tensor.shape[0]
        input_shape = (1, num_input_tokens, 1, 1)
        hidden_shape = (1, num_input_tokens, 1, hidden_size)
        kv_shape = (1, num_input_tokens, num_heads, head_size)

        assert kv_caches[0].dtype == hidden_or_intermediate_states.dtype
        kv_hidden_dtype = kv_caches[0].dtype
        input_dtype = torch.int32

        # initialize LLMDatadist data structure
        key_desc = llm_datadist.CacheDesc(
            num_layer,
            kv_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
            seq_len_dim_index=1)
        value_desc = llm_datadist.CacheDesc(
            num_layer,
            kv_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
            seq_len_dim_index=1)
        input_desc = llm_datadist.CacheDesc(
            1,
            input_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[input_dtype],
            seq_len_dim_index=-1)
        hidden_desc = llm_datadist.CacheDesc(
            1,
            hidden_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
            seq_len_dim_index=-1)

        cluster_id = self.llm_datadist_engine.cluster_id
        key_cache_keys = [llm_datadist.CacheKey(cluster_id, 0, 1)]
        value_cache_keys = [llm_datadist.CacheKey(cluster_id, 0, 2)]
        input_cache_keys = [llm_datadist.CacheKey(cluster_id, 0, 3)]
        hidden_cache_keys = [llm_datadist.CacheKey(cluster_id, 0, 4)]

        kv_transfer = self.llm_datadist_engine.kv_transfer
        self.key_buffer = kv_transfer.allocate_cache(key_desc, key_cache_keys)
        self.value_buffer = kv_transfer.allocate_cache(value_desc,
                                                       value_cache_keys)
        self.input_buffer = kv_transfer.allocate_cache(input_desc,
                                                       input_cache_keys)
        self.hidden_buffer = kv_transfer.allocate_cache(
            hidden_desc, hidden_cache_keys)

        key_buffer_addr = self.key_buffer.per_device_tensor_addrs[0]
        value_buffer_addr = self.value_buffer.per_device_tensor_addrs[0]
        input_buffer_addr = self.input_buffer.per_device_tensor_addrs[0]
        hidden_buffer_addr = self.hidden_buffer.per_device_tensor_addrs[0]

        # Torch tensors created over the allocated datadist buffers.
        self.key_cache = torchair.llm_datadist.create_npu_tensors(
            key_desc.shape, kv_hidden_dtype, key_buffer_addr)
        self.value_cache = torchair.llm_datadist.create_npu_tensors(
            value_desc.shape, kv_hidden_dtype, value_buffer_addr)
        self.input_cache = torchair.llm_datadist.create_npu_tensors(
            input_desc.shape, input_dtype, input_buffer_addr)
        self.hidden_cache = torchair.llm_datadist.create_npu_tensors(
            hidden_desc.shape, kv_hidden_dtype, hidden_buffer_addr)

        indices = torch.tensor([0], dtype=torch.int64).npu()

        # copy cache data into llm datadist cache using scatter update
        start_pos = 0
        for slen in seq_lens:
            end_pos = start_pos + slen
            current_tokens = input_tokens_tensor[start_pos:end_pos].to(
                torch.int32)
            current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

            for layer_id in range(start_layer, end_layer):
                # kv_caches and the datadist caches both hold num_layer
                # entries, so both are indexed relative to start_layer
                # (the receive side indexes the same way).
                layer_idx = layer_id - start_layer
                kv_cache = kv_caches[layer_idx]

                key_cache = kv_cache[0].view(-1, num_heads, head_size)
                value_cache = kv_cache[1].view(-1, num_heads, head_size)

                # copy key into datadist
                k = self.key_cache[layer_idx][:, start_pos:end_pos, :, :]
                new_k = key_cache[current_slot_mapping].unsqueeze(0)
                torch_npu.scatter_update_(k, indices, new_k, axis=-2)

                # copy value into datadist
                val = self.value_cache[layer_idx][:, start_pos:end_pos, :, :]
                new_val = value_cache[current_slot_mapping].unsqueeze(0)
                torch_npu.scatter_update_(val, indices, new_val, axis=-2)

            # copy input into datadist
            inp = self.input_cache[0][:, start_pos:end_pos, :, :]
            new_inp = current_tokens.view(1, current_tokens.shape[0], 1, 1)
            torch_npu.scatter_update_(inp, indices, new_inp, axis=-2)

            # copy hidden into datadist
            hid = self.hidden_cache[0][:, start_pos:end_pos, :, :]
            hidden_slice = hidden_or_intermediate_states[start_pos:end_pos]
            hid_shape0, hid_shape1 = hidden_slice.shape
            new_hid = hidden_slice.view(1, hid_shape0, 1, hid_shape1)
            torch_npu.scatter_update_(hid, indices, new_hid, axis=-2)

            # Running offset replaces the per-request sum(seq_lens[:idx]).
            start_pos = end_pos

        logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank())

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor]
    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        """Pull caches from the prompt peer and write them into ``kv_caches``.

        Returns:
            A tuple ``(hidden_or_intermediate_states, bypass_model_exec,
            model_input)``. When ``bypass_model_exec`` is False the hidden
            states are ``None`` and the caller falls back to normal model
            forwarding. ``model_input`` is returned unchanged.
        """
        bypass_model_exec = True

        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        slot_mapping = model_input.attn_metadata.slot_mapping.flatten()

        hidden_or_intermediate_states_for_one_req = []

        # get model config
        start_layer = model_executable.model.start_layer
        end_layer = model_executable.model.end_layer
        model_config = model_executable.model.config
        num_heads = model_config.num_key_value_heads // self.tp_size
        hidden_size = model_config.hidden_size
        num_attention_heads = model_config.num_attention_heads
        head_size = hidden_size // num_attention_heads
        num_layer = end_layer - start_layer

        # get input_tensor_shape and hidden_shape
        num_input_tokens = input_tokens_tensor.shape[0]
        input_shape = (1, num_input_tokens, 1, 1)
        hidden_shape = (1, num_input_tokens, 1, hidden_size)
        kv_shape = (1, num_input_tokens, num_heads, head_size)

        kv_hidden_dtype = kv_caches[0].dtype
        input_dtype = torch.int32

        # Add LLM DataDist initialization
        key_desc = llm_datadist.CacheDesc(
            num_layer,
            kv_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
            seq_len_dim_index=-1)
        value_desc = llm_datadist.CacheDesc(
            num_layer,
            kv_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
            seq_len_dim_index=-1)
        input_desc = llm_datadist.CacheDesc(
            1,
            input_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[input_dtype],
            seq_len_dim_index=-1)
        hidden_desc = llm_datadist.CacheDesc(
            1,
            hidden_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
            seq_len_dim_index=-1)

        kv_transfer = self.llm_datadist_engine.kv_transfer
        self.decode_key_buffer = kv_transfer.allocate_cache(key_desc)
        self.decode_value_buffer = kv_transfer.allocate_cache(value_desc)
        self.decode_input_buffer = kv_transfer.allocate_cache(input_desc)
        self.decode_hidden_buffer = kv_transfer.allocate_cache(hidden_desc)

        key_buffer_addrs = self.decode_key_buffer.per_device_tensor_addrs[0]
        value_buffer_addrs = (
            self.decode_value_buffer.per_device_tensor_addrs[0])
        input_buffer_addrs = (
            self.decode_input_buffer.per_device_tensor_addrs[0])
        hidden_buffer_addrs = (
            self.decode_hidden_buffer.per_device_tensor_addrs[0])

        self.key_cache = torchair.llm_datadist.create_npu_tensors(
            key_desc.shape, kv_hidden_dtype, key_buffer_addrs)
        self.value_cache = torchair.llm_datadist.create_npu_tensors(
            value_desc.shape, kv_hidden_dtype, value_buffer_addrs)
        self.input_cache = torchair.llm_datadist.create_npu_tensors(
            input_desc.shape, input_dtype, input_buffer_addrs)
        self.hidden_cache = torchair.llm_datadist.create_npu_tensors(
            hidden_desc.shape, kv_hidden_dtype, hidden_buffer_addrs)

        # Ids 1..4 match the CacheKeys registered by the send side.
        remote_cluster_id = self.cluster.remote_cluster_id
        key_cache_key = llm_datadist.CacheKeyByIdAndIndex(
            remote_cluster_id, 1, 0)
        value_cache_key = llm_datadist.CacheKeyByIdAndIndex(
            remote_cluster_id, 2, 0)
        input_cache_key = llm_datadist.CacheKeyByIdAndIndex(
            remote_cluster_id, 3, 0)
        hidden_cache_key = llm_datadist.CacheKeyByIdAndIndex(
            remote_cluster_id, 4, 0)

        kv_transfer.pull_cache(key_cache_key, self.decode_key_buffer, 0)
        kv_transfer.pull_cache(value_cache_key, self.decode_value_buffer, 0)
        kv_transfer.pull_cache(input_cache_key, self.decode_input_buffer, 0)
        kv_transfer.pull_cache(hidden_cache_key, self.decode_hidden_buffer, 0)

        keys = self.key_cache
        values = self.value_cache
        inputs = self.input_cache
        hidden = self.hidden_cache

        # enumerate different requests
        start_pos = 0
        for slen in seq_lens:
            end_pos = start_pos + slen
            # Offset of the next request; kept separately because end_pos
            # is re-derived from the received token count below.
            next_start_pos = end_pos

            num_computed_tokens = inputs[0][0, start_pos:end_pos, 0,
                                            0].shape[0]

            # check if both KV cache and the hidden states are received
            # If not, need to redo the forwarding to compute missing states
            if num_computed_tokens != slen or hidden is None:
                bypass_model_exec = False

            # update the end position based on how many tokens are cached.
            end_pos = start_pos + num_computed_tokens

            # put received KV caches into paged memory
            for layer_id in range(start_layer, end_layer):
                layer_idx = layer_id - start_layer
                kv_cache = kv_caches[layer_idx]
                key_cache, value_cache = kv_cache[0], kv_cache[1]

                sliced_key = keys[layer_idx][0, start_pos:end_pos, :, :]
                sliced_value = values[layer_idx][0, start_pos:end_pos, :, :]

                torch_npu._npu_reshape_and_cache(
                    key=sliced_key,
                    value=sliced_value,
                    key_cache=key_cache,
                    value_cache=value_cache,
                    slot_indices=slot_mapping[start_pos:end_pos])

            hidden_or_intermediate_states_for_one_req.append(
                hidden[0][0, start_pos:end_pos, 0, :])

            start_pos = next_start_pos

        if not bypass_model_exec:
            # Some of the KV cache is not retrieved
            # Here we will fall back to normal model forwarding
            # But optionally you can adjust model_input so that you only do
            # prefilling on those tokens that are missing KV caches.
            logger.info(
                "[rank%d][D]: Failed to receive all KVs and hidden "
                "states, redo model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = None
        else:
            logger.info(
                "[rank%d][D]: Successfully received all KVs and hidden "
                "states, skip model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = torch.cat(
                hidden_or_intermediate_states_for_one_req, dim=0)

        return hidden_or_intermediate_states, bypass_model_exec, model_input

    def close(self, ):
        """Unlink from the remote cluster, if this instance linked to one."""
        # The producer never links to a cluster, so there is nothing to undo.
        if self.cluster is None:
            return
        self.llm_datadist_engine.data_dist.unlink_clusters([self.cluster],
                                                           5000)

View File

@@ -18,6 +18,17 @@ env_variables: Dict[str, Callable[[], Any]] = {
lambda: os.getenv("ASCEND_HOME_PATH", None), lambda: os.getenv("ASCEND_HOME_PATH", None),
"LD_LIBRARY_PATH": "LD_LIBRARY_PATH":
lambda: os.getenv("LD_LIBRARY_PATH", None), lambda: os.getenv("LD_LIBRARY_PATH", None),
# Used for disaggregated prefilling
"HCCN_PATH":
lambda: os.getenv("HCCN_PATH", "/usr/local/Ascend/driver/tools/hccn_tool"),
"PROMPT_DEVICE_ID":
lambda: os.getenv("PROMPT_DEVICE_ID", None),
"DECODE_DEVICE_ID":
lambda: os.getenv("DECODE_DEVICE_ID", None),
"LLMDATADIST_COMM_PORT":
lambda: os.getenv("LLMDATADIST_COMM_PORT", "26000"),
"LLMDATADIST_SYNC_CACHE_WAIT_TIME":
lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000")
} }

View File

@@ -150,7 +150,7 @@ class NPUPlatform(Platform):
@classmethod @classmethod
def get_device_communicator_cls(cls) -> str: def get_device_communicator_cls(cls) -> str:
return "vllm_ascend.communicator.NPUCommunicator" return "vllm_ascend.distributed.communicator.NPUCommunicator"
@classmethod @classmethod
def is_pin_memory_available(cls): def is_pin_memory_available(cls):

View File

@@ -24,8 +24,9 @@ import torch
import torch.distributed import torch.distributed
from torch import nn from torch import nn
from vllm import envs from vllm import envs
from vllm.config import ParallelConfig, VllmConfig from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_kv_transfer_initialized,
ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce) set_custom_all_reduce)
from vllm.logger import logger from vllm.logger import logger
@@ -161,8 +162,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
raise RuntimeError( raise RuntimeError(
f"Not support device type: {self.device_config.device}") f"Not support device type: {self.device_config.device}")
# Initialize the distributed environment. # Initialize the distributed environment.
self._init_worker_distributed_environment(self.parallel_config, self._init_worker_distributed_environment(self.vllm_config, self.rank,
self.rank,
self.distributed_init_method, self.distributed_init_method,
self.local_rank) self.local_rank)
# Set random seed. # Set random seed.
@@ -450,12 +450,13 @@ class NPUWorker(LocalOrDistributedWorkerBase):
def _init_worker_distributed_environment( def _init_worker_distributed_environment(
self, self,
parallel_config: ParallelConfig, vllm_config: VllmConfig,
rank: int, rank: int,
distributed_init_method: Optional[str] = None, distributed_init_method: Optional[str] = None,
local_rank: int = -1, local_rank: int = -1,
backend: str = "hccl") -> None: backend: str = "hccl") -> None:
"""Initialize the distributed environment.""" """Initialize the distributed environment."""
parallel_config = self.parallel_config
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce) set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
init_distributed_environment(parallel_config.world_size, rank, init_distributed_environment(parallel_config.world_size, rank,
distributed_init_method, local_rank, distributed_init_method, local_rank,
@@ -463,6 +464,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
parallel_config.tensor_parallel_size, parallel_config.tensor_parallel_size,
parallel_config.pipeline_parallel_size) parallel_config.pipeline_parallel_size)
ensure_kv_transfer_initialized(vllm_config)
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free, def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,

View File

@@ -25,7 +25,8 @@ import torch.nn as nn
import torch_npu import torch_npu
from vllm import envs from vllm import envs
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.distributed import (ensure_model_parallel_initialized, from vllm.distributed import (ensure_kv_transfer_initialized,
ensure_model_parallel_initialized,
init_distributed_environment, init_distributed_environment,
set_custom_all_reduce) set_custom_all_reduce)
from vllm.logger import logger from vllm.logger import logger
@@ -197,6 +198,7 @@ class NPUWorker(WorkerBase):
ensure_model_parallel_initialized( ensure_model_parallel_initialized(
self.parallel_config.tensor_parallel_size, self.parallel_config.tensor_parallel_size,
self.parallel_config.pipeline_parallel_size) self.parallel_config.pipeline_parallel_size)
ensure_kv_transfer_initialized(self.vllm_config)
def _init_profiler(self): def _init_profiler(self):
# Torch profiler. Enabled and configured through env vars: # Torch profiler. Enabled and configured through env vars:
@@ -230,4 +232,4 @@ class NPUWorker(WorkerBase):
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler( on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
torch_profiler_trace_dir)) torch_profiler_trace_dir))
else: else:
return None return None