[Feature] Add PD separation feature (#432)
### What this PR does / why we need it? Adapt Disaggregated Prefill feature onto Ascend device ### Does this PR introduce _any_ user-facing change? no ### How was this patch tested? The test usage has been provided alongwith the PR, in examples/offline_disaggregated_prefill_npu.py To run it, do this ``` export PROMPT_DEVICE_ID=0,1 export DECODE_DEVICE_ID=2,3 python examples/offline_disaggregated_prefill_npu.py ``` --------- Signed-off-by: ZihuiQian <qianzihui@huawei.com> Co-authored-by: ZihuiQian <qianzihui@huawei.com>
This commit is contained in:
140
examples/offline_disaggregated_prefill_npu.py
Normal file
140
examples/offline_disaggregated_prefill_npu.py
Normal file
@@ -0,0 +1,140 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/examples/offline_inference/basic.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import time
|
||||
from multiprocessing import Event, Process
|
||||
|
||||
|
||||
def clean_up():
|
||||
import gc
|
||||
|
||||
import torch
|
||||
from vllm.distributed.parallel_state import (
|
||||
destroy_distributed_environment, destroy_model_parallel)
|
||||
destroy_model_parallel()
|
||||
destroy_distributed_environment()
|
||||
gc.collect()
|
||||
torch.npu.empty_cache()
|
||||
|
||||
|
||||
def run_prefill(prefill_done, process_close):
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
prompts = [
|
||||
"Hello, how are you today?", "Hi, what is your name?",
|
||||
"Tell me a very long story.", "what is your favourite book?"
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
|
||||
|
||||
ktc = KVTransferConfig.from_cli(
|
||||
'{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}'
|
||||
)
|
||||
|
||||
# Set NPU memory utilization to 0.8
|
||||
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=2000,
|
||||
gpu_memory_utilization=0.8,
|
||||
tensor_parallel_size=2)
|
||||
|
||||
llm.generate(prompts, sampling_params)
|
||||
print("Prefill node is finished.")
|
||||
prefill_done.set()
|
||||
|
||||
# To keep the prefill node running in case the decode node is not done
|
||||
# otherwise, the script might exit prematurely, causing incomplete decoding.
|
||||
try:
|
||||
while not process_close.is_set():
|
||||
time.sleep(1)
|
||||
except KeyboardInterrupt:
|
||||
print("Script stopped by user.")
|
||||
finally:
|
||||
print("Cleanup prefill resources")
|
||||
del llm
|
||||
clean_up()
|
||||
|
||||
|
||||
def run_decode(prefill_done):
|
||||
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "2,3"
|
||||
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import KVTransferConfig
|
||||
|
||||
prompts = [
|
||||
"Hello, how are you today?", "Hi, what is your name?",
|
||||
"Tell me a very long story.", "what is your favourite book?"
|
||||
]
|
||||
sampling_params = SamplingParams(temperature=0, top_p=0.95)
|
||||
|
||||
ktc = KVTransferConfig.from_cli(
|
||||
'{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
|
||||
)
|
||||
|
||||
llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
|
||||
kv_transfer_config=ktc,
|
||||
max_model_len=2000,
|
||||
gpu_memory_utilization=0.8,
|
||||
tensor_parallel_size=2)
|
||||
|
||||
# Wait for the producer to start the consumer
|
||||
print("Waiting for prefill node to finish...")
|
||||
prefill_done.wait()
|
||||
|
||||
# At this point when the prefill_done is set, the kv-cache should have been
|
||||
# transferred to this decode node, so we can start decoding.
|
||||
outputs = llm.generate(prompts, sampling_params)
|
||||
for output in outputs:
|
||||
prompt = output.prompt
|
||||
generated_text = output.outputs[0].text
|
||||
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
|
||||
|
||||
del llm
|
||||
clean_up()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
mp.get_context('spawn')
|
||||
|
||||
prefill_done = Event()
|
||||
process_close = Event()
|
||||
prefill_process = Process(target=run_prefill,
|
||||
args=(
|
||||
prefill_done,
|
||||
process_close,
|
||||
))
|
||||
decode_process = Process(target=run_decode, args=(prefill_done, ))
|
||||
|
||||
# Start prefill node
|
||||
prefill_process.start()
|
||||
|
||||
# Start decode node
|
||||
decode_process.start()
|
||||
|
||||
# Terminate the prefill node when decode is finished
|
||||
decode_process.join()
|
||||
|
||||
# Terminate prefill process
|
||||
process_close.set()
|
||||
prefill_process.join()
|
||||
prefill_process.terminate()
|
||||
print("All process done!")
|
||||
6
vllm_ascend/distributed/__init__.py
Normal file
6
vllm_ascend/distributed/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from vllm.distributed.kv_transfer.kv_connector.factory import \
|
||||
KVConnectorFactory
|
||||
|
||||
KVConnectorFactory.register_connector(
|
||||
"AscendHcclConnector", "vllm_ascend.distributed.llmdatadist_connector",
|
||||
"LLMDataDistConnector")
|
||||
465
vllm_ascend/distributed/llmdatadist_connector.py
Normal file
465
vllm_ascend/distributed/llmdatadist_connector.py
Normal file
@@ -0,0 +1,465 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import os
|
||||
import re
|
||||
import subprocess
|
||||
from typing import TYPE_CHECKING, List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch_npu
|
||||
import torchair # type: ignore
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
|
||||
from vllm.logger import logger
|
||||
from vllm.sequence import IntermediateTensors
|
||||
|
||||
import vllm_ascend.envs as envs
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
|
||||
|
||||
import llm_datadist # type: ignore
|
||||
|
||||
TORCH_DTYPE_TO_NPU_DTYPE = {
|
||||
torch.half: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.float16: llm_datadist.DataType.DT_FLOAT16,
|
||||
torch.bfloat16: llm_datadist.DataType.DT_BF16,
|
||||
torch.float: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.float32: llm_datadist.DataType.DT_FLOAT,
|
||||
torch.int8: llm_datadist.DataType.DT_INT8,
|
||||
torch.int64: llm_datadist.DataType.DT_INT64,
|
||||
torch.int32: llm_datadist.DataType.DT_INT32
|
||||
}
|
||||
|
||||
# Get all device ips using hccn_tool
|
||||
HCCN_TOOL_PATH = envs.HCCN_PATH
|
||||
|
||||
|
||||
def get_device_ips():
|
||||
world_size = 8
|
||||
npu_info = subprocess.run(['npu-smi', 'info', '-m'],
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True)
|
||||
if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH):
|
||||
raise RuntimeError("No npu-smi/hccn_tool tools provided for NPU.")
|
||||
npu_start_idx = int(
|
||||
re.match(r'.*\n\t([0-9]+).*', npu_info.stdout).group(1))
|
||||
device_ip_list = []
|
||||
for ip_offset in range(world_size):
|
||||
cmd = [
|
||||
HCCN_TOOL_PATH, '-i', f'{npu_start_idx + ip_offset}', '-ip', '-g'
|
||||
]
|
||||
device_ip_info = subprocess.run(cmd,
|
||||
stdout=subprocess.PIPE,
|
||||
stderr=subprocess.PIPE,
|
||||
universal_newlines=True)
|
||||
device_ip = re.match(r'ipaddr:(.*)\n', device_ip_info.stdout).group(1)
|
||||
device_ip_list.append(device_ip)
|
||||
return device_ip_list
|
||||
|
||||
|
||||
class KVTransferEngine:
|
||||
|
||||
def __init__(self, world_size, n_layer, role, local_rank):
|
||||
self.world_size = world_size
|
||||
self.n_layer = n_layer
|
||||
self.role = role
|
||||
self.device_ip_list = get_device_ips()
|
||||
self.local_rank = local_rank
|
||||
self.cluster_id = local_rank
|
||||
self.data_dist = llm_datadist.LLMDataDist(self.role, self.cluster_id)
|
||||
|
||||
prompt_device_ids = envs.PROMPT_DEVICE_ID
|
||||
decode_device_ids = envs.DECODE_DEVICE_ID
|
||||
if prompt_device_ids is None or decode_device_ids is None:
|
||||
raise ValueError(
|
||||
"Please specify env PROMPT_DEVICE_ID or DECODE_DEVICE_ID")
|
||||
|
||||
prompt_ids = [
|
||||
int(x.strip()) for x in prompt_device_ids.split(",") if x.strip()
|
||||
]
|
||||
decode_ids = [
|
||||
int(x.strip()) for x in decode_device_ids.split(",") if x.strip()
|
||||
]
|
||||
|
||||
self.prompt_ip_list = [self.device_ip_list[i] for i in prompt_ids]
|
||||
self.decode_ip_list = [self.device_ip_list[i] for i in decode_ids]
|
||||
|
||||
def prepare_data_dist(self):
|
||||
options = {
|
||||
"llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME,
|
||||
}
|
||||
if self.role == llm_datadist.LLMRole.PROMPT:
|
||||
options["ge.exec.deviceId"] = str(self.local_rank)
|
||||
options[
|
||||
"llm.listenIpInfo"] = f"{self.prompt_ip_list[self.local_rank]}:{envs.LLMDATADIST_COMM_PORT}"
|
||||
else:
|
||||
options["ge.exec.deviceId"] = str(self.local_rank)
|
||||
self.data_dist.init(options)
|
||||
self.kv_transfer = self.data_dist.kv_cache_manager
|
||||
logger.info(
|
||||
f"{self.local_rank}/{self.world_size} rank data dist is ready")
|
||||
|
||||
def make_cluster(self, prefill_ip, cluster_id=-1):
|
||||
cluster = llm_datadist.LLMClusterInfo()
|
||||
cluster.remote_cluster_id = cluster_id
|
||||
local_ip = self.decode_ip_list[self.local_rank]
|
||||
remote_ip = prefill_ip
|
||||
cluster.append_local_ip_info(local_ip, 0)
|
||||
cluster.append_remote_ip_info(remote_ip, 26000)
|
||||
return cluster
|
||||
|
||||
|
||||
class LLMDataDistConnector(KVConnectorBase):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rank: int,
|
||||
local_rank: int,
|
||||
config: VllmConfig,
|
||||
):
|
||||
self.config = config
|
||||
self.tp_size = config.parallel_config.tensor_parallel_size
|
||||
self.rank = rank
|
||||
self.local_rank = local_rank
|
||||
|
||||
if self.config.kv_transfer_config.kv_role == "kv_producer":
|
||||
self.role = llm_datadist.LLMRole.PROMPT
|
||||
elif self.config.kv_transfer_config.kv_role == "kv_consumer":
|
||||
self.role = llm_datadist.LLMRole.DECODER
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"kv_role should be inside [kv_producer, kv_consumer]")
|
||||
|
||||
self.world_size = self.config.parallel_config.world_size
|
||||
self.n_layer = self.config.model_config.get_num_layers(
|
||||
self.config.parallel_config)
|
||||
|
||||
self.llm_datadist_engine = KVTransferEngine(self.world_size,
|
||||
self.n_layer, self.role,
|
||||
self.local_rank)
|
||||
if self.role == llm_datadist.LLMRole.PROMPT:
|
||||
self.llm_datadist_engine.prepare_data_dist()
|
||||
else:
|
||||
self.llm_datadist_engine.prepare_data_dist()
|
||||
self.cluster = self.llm_datadist_engine.make_cluster(
|
||||
self.llm_datadist_engine.prompt_ip_list[self.local_rank],
|
||||
self.llm_datadist_engine.cluster_id)
|
||||
_, ret = self.llm_datadist_engine.data_dist.link_clusters(
|
||||
[self.cluster], 20000)
|
||||
logger.info(f"local_rank {self.local_rank} link, ret={ret}")
|
||||
|
||||
def send_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor],
|
||||
hidden_or_intermediate_states: Union[torch.Tensor, IntermediateTensors]
|
||||
) -> None:
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
|
||||
model_config = model_executable.model.config
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
head_size = int(hidden_size / num_attention_heads)
|
||||
|
||||
num_layer = end_layer - start_layer
|
||||
|
||||
# Get shape of input_tokens_tensor and kv_cache
|
||||
input_shape = (1, input_tokens_tensor.shape[0], 1, 1)
|
||||
hidden_shape = (1, input_tokens_tensor.shape[0], 1, hidden_size)
|
||||
kv_shape = (1, input_tokens_tensor.shape[0], num_heads, head_size)
|
||||
|
||||
assert kv_caches[0].dtype == hidden_or_intermediate_states.dtype
|
||||
kv_hidden_dtype = kv_caches[0].dtype
|
||||
input_dtype = torch.int32
|
||||
|
||||
# initialize LLMDatadist data structure
|
||||
key_desc = llm_datadist.CacheDesc(
|
||||
num_layer,
|
||||
kv_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
|
||||
seq_len_dim_index=1)
|
||||
value_desc = llm_datadist.CacheDesc(
|
||||
num_layer,
|
||||
kv_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
|
||||
seq_len_dim_index=1)
|
||||
input_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
input_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[input_dtype],
|
||||
seq_len_dim_index=-1)
|
||||
hidden_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
hidden_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
|
||||
seq_len_dim_index=-1)
|
||||
|
||||
key_cache_keys = [
|
||||
llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 1)
|
||||
]
|
||||
value_cache_keys = [
|
||||
llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 2)
|
||||
]
|
||||
input_cache_keys = [
|
||||
llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 3)
|
||||
]
|
||||
hidden_cache_keys = [
|
||||
llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 4)
|
||||
]
|
||||
|
||||
self.key_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
key_desc, key_cache_keys)
|
||||
self.value_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
value_desc, value_cache_keys)
|
||||
self.input_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
input_desc, input_cache_keys)
|
||||
self.hidden_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
hidden_desc, hidden_cache_keys)
|
||||
|
||||
key_buffer_addr = self.key_buffer.per_device_tensor_addrs[0]
|
||||
value_buffer_addr = self.value_buffer.per_device_tensor_addrs[0]
|
||||
input_buffer_addr = self.input_buffer.per_device_tensor_addrs[0]
|
||||
hidden_buffer_addr = self.hidden_buffer.per_device_tensor_addrs[0]
|
||||
|
||||
self.key_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
key_desc.shape, kv_hidden_dtype, key_buffer_addr)
|
||||
self.value_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
value_desc.shape, kv_hidden_dtype, value_buffer_addr)
|
||||
self.input_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
input_desc.shape, input_dtype, input_buffer_addr)
|
||||
self.hidden_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
hidden_desc.shape, kv_hidden_dtype, hidden_buffer_addr)
|
||||
|
||||
indices = torch.tensor([0], dtype=torch.int64).npu()
|
||||
|
||||
# copy cache data into llm datadist cache using scatter update
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos].to(
|
||||
torch.int32)
|
||||
|
||||
for layer_id in range(start_layer, end_layer):
|
||||
kv_cache = kv_caches[layer_id - start_layer]
|
||||
|
||||
key_cache = kv_cache[0].view(-1, num_heads, head_size)
|
||||
value_cache = kv_cache[1].view(-1, num_heads, head_size)
|
||||
|
||||
current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
|
||||
|
||||
# copy key into datadist
|
||||
k = self.key_cache[layer_id][:, start_pos:end_pos, :, :]
|
||||
new_k = key_cache[current_slot_mapping].unsqueeze(0)
|
||||
torch_npu.scatter_update_(k, indices, new_k, axis=-2)
|
||||
|
||||
# copy value into datadist
|
||||
val = self.value_cache[layer_id][:, start_pos:end_pos, :, :]
|
||||
new_val = value_cache[current_slot_mapping].unsqueeze(0)
|
||||
torch_npu.scatter_update_(val, indices, new_val, axis=-2)
|
||||
|
||||
# copy input into datadist
|
||||
inp = self.input_cache[0][:, start_pos:end_pos, :, :]
|
||||
new_inp = current_tokens.view(1, current_tokens.shape[0], 1, 1)
|
||||
torch_npu.scatter_update_(inp, indices, new_inp, axis=-2)
|
||||
|
||||
# copy hidden into datadist
|
||||
hid = self.hidden_cache[0][:, start_pos:end_pos, :, :]
|
||||
hid_shape0, hid_shape1 = hidden_or_intermediate_states[
|
||||
start_pos:end_pos].shape
|
||||
new_hid = hidden_or_intermediate_states[start_pos:end_pos].view(
|
||||
1, hid_shape0, 1, hid_shape1)
|
||||
torch_npu.scatter_update_(hid, indices, new_hid, axis=-2)
|
||||
|
||||
logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank())
|
||||
|
||||
def recv_kv_caches_and_hidden_states(
|
||||
self, model_executable: torch.nn.Module,
|
||||
model_input: "ModelInputForGPUWithSamplingMetadata",
|
||||
kv_caches: List[torch.Tensor]
|
||||
) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
|
||||
"ModelInputForGPUWithSamplingMetadata"]:
|
||||
bypass_model_exec = True
|
||||
|
||||
input_tokens_tensor = model_input.input_tokens
|
||||
seq_lens = model_input.attn_metadata.seq_lens
|
||||
slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
|
||||
|
||||
hidden_or_intermediate_states_for_one_req = []
|
||||
|
||||
input_tokens_list = []
|
||||
num_computed_tokens_list = []
|
||||
start_pos_list = []
|
||||
|
||||
# get model config
|
||||
start_layer = model_executable.model.start_layer
|
||||
end_layer = model_executable.model.end_layer
|
||||
model_config = model_executable.model.config
|
||||
num_heads = int(model_config.num_key_value_heads / self.tp_size)
|
||||
hidden_size = model_config.hidden_size
|
||||
num_attention_heads = model_config.num_attention_heads
|
||||
head_size = int(hidden_size / num_attention_heads)
|
||||
num_layer = end_layer - start_layer
|
||||
|
||||
# get input_tensor_shape and hidden_shape
|
||||
input_shape = (1, input_tokens_tensor.shape[0], 1, 1)
|
||||
hidden_shape = (1, input_tokens_tensor.shape[0], 1, hidden_size)
|
||||
kv_shape = (1, input_tokens_tensor.shape[0], num_heads, head_size)
|
||||
|
||||
kv_hidden_dtype = kv_caches[0].dtype
|
||||
input_dtype = torch.int32
|
||||
|
||||
# Add LLM DataDist initialization
|
||||
key_desc = llm_datadist.CacheDesc(
|
||||
num_layer,
|
||||
kv_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
|
||||
seq_len_dim_index=-1)
|
||||
value_desc = llm_datadist.CacheDesc(
|
||||
num_layer,
|
||||
kv_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
|
||||
seq_len_dim_index=-1)
|
||||
input_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
input_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[input_dtype],
|
||||
seq_len_dim_index=-1)
|
||||
hidden_desc = llm_datadist.CacheDesc(
|
||||
1,
|
||||
hidden_shape,
|
||||
TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
|
||||
seq_len_dim_index=-1)
|
||||
self.decode_key_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
key_desc)
|
||||
self.decode_value_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
value_desc)
|
||||
self.decode_input_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
input_desc)
|
||||
self.decode_hidden_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
|
||||
hidden_desc)
|
||||
key_buffer_addrs = self.decode_key_buffer.per_device_tensor_addrs[0]
|
||||
value_buffer_addrs = self.decode_value_buffer.per_device_tensor_addrs[
|
||||
0]
|
||||
input_buffer_addrs = self.decode_input_buffer.per_device_tensor_addrs[
|
||||
0]
|
||||
hidden_buffer_addrs = self.decode_hidden_buffer.per_device_tensor_addrs[
|
||||
0]
|
||||
self.key_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
key_desc.shape, kv_hidden_dtype, key_buffer_addrs)
|
||||
self.value_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
value_desc.shape, kv_hidden_dtype, value_buffer_addrs)
|
||||
self.input_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
input_desc.shape, input_dtype, input_buffer_addrs)
|
||||
self.hidden_cache = torchair.llm_datadist.create_npu_tensors(
|
||||
hidden_desc.shape, kv_hidden_dtype, hidden_buffer_addrs)
|
||||
|
||||
key_cache_key = llm_datadist.CacheKeyByIdAndIndex(
|
||||
self.cluster.remote_cluster_id, 1, 0)
|
||||
value_cache_key = llm_datadist.CacheKeyByIdAndIndex(
|
||||
self.cluster.remote_cluster_id, 2, 0)
|
||||
input_cache_key = llm_datadist.CacheKeyByIdAndIndex(
|
||||
self.cluster.remote_cluster_id, 3, 0)
|
||||
hidden_cache_key = llm_datadist.CacheKeyByIdAndIndex(
|
||||
self.cluster.remote_cluster_id, 4, 0)
|
||||
|
||||
self.llm_datadist_engine.kv_transfer.pull_cache(
|
||||
key_cache_key, self.decode_key_buffer, 0)
|
||||
self.llm_datadist_engine.kv_transfer.pull_cache(
|
||||
value_cache_key, self.decode_value_buffer, 0)
|
||||
self.llm_datadist_engine.kv_transfer.pull_cache(
|
||||
input_cache_key, self.decode_input_buffer, 0)
|
||||
self.llm_datadist_engine.kv_transfer.pull_cache(
|
||||
hidden_cache_key, self.decode_hidden_buffer, 0)
|
||||
|
||||
keys = self.key_cache
|
||||
values = self.value_cache
|
||||
inputs = self.input_cache
|
||||
hidden = self.hidden_cache
|
||||
|
||||
# enumerate different requests
|
||||
for idx, slen in enumerate(seq_lens):
|
||||
start_pos = sum(seq_lens[:idx])
|
||||
end_pos = start_pos + slen
|
||||
current_tokens = input_tokens_tensor[start_pos:end_pos]
|
||||
num_tokens = slen
|
||||
|
||||
# collecting data for rebuilding the input
|
||||
input_tokens_list.append(current_tokens)
|
||||
start_pos_list.append(start_pos)
|
||||
|
||||
num_computed_tokens = inputs[0][0, start_pos:end_pos, 0,
|
||||
0].shape[0]
|
||||
num_computed_tokens_list.append(num_computed_tokens)
|
||||
|
||||
# check if both KV cache and the hidden states are received
|
||||
# If not, need to redo the forwarding to compute missing states
|
||||
if not all([(num_computed_tokens == num_tokens), hidden is not None
|
||||
]):
|
||||
bypass_model_exec = False
|
||||
|
||||
# update the end position based on how many tokens are cached.
|
||||
end_pos = start_pos + num_computed_tokens
|
||||
|
||||
# put received KV caches into paged memory
|
||||
for i in range(model_executable.model.start_layer,
|
||||
model_executable.model.end_layer):
|
||||
kv_cache = kv_caches[i - model_executable.model.start_layer]
|
||||
key_cache, value_cache = kv_cache[0], kv_cache[1]
|
||||
|
||||
sliced_key = keys[i - model_executable.model.start_layer][
|
||||
0, start_pos:end_pos, :, :]
|
||||
sliced_value = values[i - model_executable.model.start_layer][
|
||||
0, start_pos:end_pos, :, :]
|
||||
|
||||
torch_npu._npu_reshape_and_cache(
|
||||
key=sliced_key,
|
||||
value=sliced_value,
|
||||
key_cache=key_cache,
|
||||
value_cache=value_cache,
|
||||
slot_indices=slot_mapping[start_pos:end_pos])
|
||||
|
||||
hidden_or_intermediate_states_for_one_req.append(
|
||||
hidden[0][0, start_pos:end_pos, 0, :])
|
||||
|
||||
if not bypass_model_exec:
|
||||
# Some of the KV cache is not retrieved
|
||||
# Here we will fall back to normal model forwarding
|
||||
# But optionally you can adjust model_input so that you only do
|
||||
# prefilling on those tokens that are missing KV caches.
|
||||
logger.info(
|
||||
"[rank%d][D]: Failed to receive all KVs and hidden "
|
||||
"states, redo model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = None
|
||||
else:
|
||||
logger.info(
|
||||
"[rank%d][D]: Successfully received all KVs and hidden "
|
||||
"states, skip model forwarding.", torch.distributed.get_rank())
|
||||
hidden_or_intermediate_states = torch.cat(
|
||||
hidden_or_intermediate_states_for_one_req, dim=0)
|
||||
|
||||
return hidden_or_intermediate_states, bypass_model_exec, model_input
|
||||
|
||||
def close(self, ):
|
||||
self.llm_datadist_engine.data_dist.unlink_clusters([self.cluster],
|
||||
5000)
|
||||
@@ -18,6 +18,17 @@ env_variables: Dict[str, Callable[[], Any]] = {
|
||||
lambda: os.getenv("ASCEND_HOME_PATH", None),
|
||||
"LD_LIBRARY_PATH":
|
||||
lambda: os.getenv("LD_LIBRARY_PATH", None),
|
||||
# Used for disaggregated prefilling
|
||||
"HCCN_PATH":
|
||||
lambda: os.getenv("HCCN_PATH", "/usr/local/Ascend/driver/tools/hccn_tool"),
|
||||
"PROMPT_DEVICE_ID":
|
||||
lambda: os.getenv("PROMPT_DEVICE_ID", None),
|
||||
"DECODE_DEVICE_ID":
|
||||
lambda: os.getenv("DECODE_DEVICE_ID", None),
|
||||
"LLMDATADIST_COMM_PORT":
|
||||
lambda: os.getenv("LLMDATADIST_COMM_PORT", "26000"),
|
||||
"LLMDATADIST_SYNC_CACHE_WAIT_TIME":
|
||||
lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000")
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -150,7 +150,7 @@ class NPUPlatform(Platform):
|
||||
|
||||
@classmethod
|
||||
def get_device_communicator_cls(cls) -> str:
|
||||
return "vllm_ascend.communicator.NPUCommunicator"
|
||||
return "vllm_ascend.distributed.communicator.NPUCommunicator"
|
||||
|
||||
@classmethod
|
||||
def is_pin_memory_available(cls):
|
||||
|
||||
@@ -24,8 +24,9 @@ import torch
|
||||
import torch.distributed
|
||||
from torch import nn
|
||||
from vllm import envs
|
||||
from vllm.config import ParallelConfig, VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_kv_transfer_initialized,
|
||||
ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
from vllm.logger import logger
|
||||
@@ -161,8 +162,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
raise RuntimeError(
|
||||
f"Not support device type: {self.device_config.device}")
|
||||
# Initialize the distributed environment.
|
||||
self._init_worker_distributed_environment(self.parallel_config,
|
||||
self.rank,
|
||||
self._init_worker_distributed_environment(self.vllm_config, self.rank,
|
||||
self.distributed_init_method,
|
||||
self.local_rank)
|
||||
# Set random seed.
|
||||
@@ -450,12 +450,13 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
|
||||
def _init_worker_distributed_environment(
|
||||
self,
|
||||
parallel_config: ParallelConfig,
|
||||
vllm_config: VllmConfig,
|
||||
rank: int,
|
||||
distributed_init_method: Optional[str] = None,
|
||||
local_rank: int = -1,
|
||||
backend: str = "hccl") -> None:
|
||||
"""Initialize the distributed environment."""
|
||||
parallel_config = self.parallel_config
|
||||
set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
|
||||
init_distributed_environment(parallel_config.world_size, rank,
|
||||
distributed_init_method, local_rank,
|
||||
@@ -463,6 +464,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
|
||||
ensure_model_parallel_initialized(
|
||||
parallel_config.tensor_parallel_size,
|
||||
parallel_config.pipeline_parallel_size)
|
||||
ensure_kv_transfer_initialized(vllm_config)
|
||||
|
||||
|
||||
def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
|
||||
|
||||
@@ -25,7 +25,8 @@ import torch.nn as nn
|
||||
import torch_npu
|
||||
from vllm import envs
|
||||
from vllm.config import VllmConfig
|
||||
from vllm.distributed import (ensure_model_parallel_initialized,
|
||||
from vllm.distributed import (ensure_kv_transfer_initialized,
|
||||
ensure_model_parallel_initialized,
|
||||
init_distributed_environment,
|
||||
set_custom_all_reduce)
|
||||
from vllm.logger import logger
|
||||
@@ -197,6 +198,7 @@ class NPUWorker(WorkerBase):
|
||||
ensure_model_parallel_initialized(
|
||||
self.parallel_config.tensor_parallel_size,
|
||||
self.parallel_config.pipeline_parallel_size)
|
||||
ensure_kv_transfer_initialized(self.vllm_config)
|
||||
|
||||
def _init_profiler(self):
|
||||
# Torch profiler. Enabled and configured through env vars:
|
||||
@@ -230,4 +232,4 @@ class NPUWorker(WorkerBase):
|
||||
on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
|
||||
torch_profiler_trace_dir))
|
||||
else:
|
||||
return None
|
||||
return None
|
||||
Reference in New Issue
Block a user