diff --git a/examples/offline_disaggregated_prefill_npu.py b/examples/offline_disaggregated_prefill_npu.py
new file mode 100644
index 0000000..9cea63a
--- /dev/null
+++ b/examples/offline_disaggregated_prefill_npu.py
@@ -0,0 +1,140 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+# Adapted from vllm-project/vllm/examples/offline_inference/basic.py
+# Copyright 2023 The vLLM team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import multiprocessing as mp
+import os
+import time
+from multiprocessing import Event, Process
+
+
+def clean_up():
+    import gc
+
+    import torch
+    from vllm.distributed.parallel_state import (
+        destroy_distributed_environment, destroy_model_parallel)
+    destroy_model_parallel()
+    destroy_distributed_environment()
+    gc.collect()
+    torch.npu.empty_cache()
+
+
+def run_prefill(prefill_done, process_close):
+    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
+
+    from vllm import LLM, SamplingParams
+    from vllm.config import KVTransferConfig
+
+    prompts = [
+        "Hello, how are you today?", "Hi, what is your name?",
+        "Tell me a very long story.", "what is your favourite book?"
+    ]
+    sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=1)
+
+    ktc = KVTransferConfig.from_cli(
+        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_producer", "kv_parallel_size":2}'
+    )
+
+    # Cap NPU memory utilization at 0.8 (gpu_memory_utilization maps to NPU
+    # memory on Ascend).
+    llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+              kv_transfer_config=ktc,
+              max_model_len=2000,
+              gpu_memory_utilization=0.8,
+              tensor_parallel_size=2)
+
+    llm.generate(prompts, sampling_params)
+    print("Prefill node is finished.")
+    prefill_done.set()
+
+    # Keep the prefill node alive until the decode node is done; otherwise
+    # this script would exit prematurely and interrupt decoding.
+    try:
+        while not process_close.is_set():
+            time.sleep(1)
+    except KeyboardInterrupt:
+        print("Script stopped by user.")
+    finally:
+        print("Cleaning up prefill resources.")
+        del llm
+        clean_up()
+
+
+def run_decode(prefill_done):
+    os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "2,3"
+
+    from vllm import LLM, SamplingParams
+    from vllm.config import KVTransferConfig
+
+    prompts = [
+        "Hello, how are you today?", "Hi, what is your name?",
+        "Tell me a very long story.", "what is your favourite book?"
+    ]
+    sampling_params = SamplingParams(temperature=0, top_p=0.95)
+
+    ktc = KVTransferConfig.from_cli(
+        '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu","kv_role":"kv_consumer","kv_parallel_size":2}'
+    )
+
+    llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
+              kv_transfer_config=ktc,
+              max_model_len=2000,
+              gpu_memory_utilization=0.8,
+              tensor_parallel_size=2)
+
+    # Wait for the prefill (producer) node to finish before decoding.
+    print("Waiting for prefill node to finish...")
+    prefill_done.wait()
+
+    # Once prefill_done is set, the KV cache has been transferred to this
+    # decode node, so decoding can start.
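+    # Decode pulls the prefill node's KV cache through the
+    # AscendHcclConnector (kv_consumer role), so the prompt tokens are not
+    # recomputed here.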
+    outputs = llm.generate(prompts, sampling_params)
+    for output in outputs:
+        prompt = output.prompt
+        generated_text = output.outputs[0].text
+        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
+
+    del llm
+    clean_up()
+
+
+if __name__ == "__main__":
+    # Use the spawn start method so each node gets a clean NPU context.
+    mp.set_start_method('spawn')
+
+    prefill_done = Event()
+    process_close = Event()
+    prefill_process = Process(target=run_prefill,
+                              args=(
+                                  prefill_done,
+                                  process_close,
+                              ))
+    decode_process = Process(target=run_decode, args=(prefill_done, ))
+
+    # Start prefill node
+    prefill_process.start()
+
+    # Start decode node
+    decode_process.start()
+
+    # Wait for the decode node to finish
+    decode_process.join()
+
+    # Signal the prefill node to shut down, then wait for it
+    process_close.set()
+    prefill_process.join()
+    prefill_process.terminate()
+    print("All processes done!")
diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py
new file mode 100644
index 0000000..2b2fd2c
--- /dev/null
+++ b/vllm_ascend/distributed/__init__.py
@@ -0,0 +1,6 @@
+from vllm.distributed.kv_transfer.kv_connector.factory import \
+    KVConnectorFactory
+
+KVConnectorFactory.register_connector(
+    "AscendHcclConnector", "vllm_ascend.distributed.llmdatadist_connector",
+    "LLMDataDistConnector")
diff --git a/vllm_ascend/communicator.py b/vllm_ascend/distributed/communicator.py
similarity index 100%
rename from vllm_ascend/communicator.py
rename to vllm_ascend/distributed/communicator.py
diff --git a/vllm_ascend/distributed/llmdatadist_connector.py b/vllm_ascend/distributed/llmdatadist_connector.py
new file mode 100644
index 0000000..6e0d4e5
--- /dev/null
+++ b/vllm_ascend/distributed/llmdatadist_connector.py
@@ -0,0 +1,465 @@
+#
+# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
+# This file is a part of the vllm-ascend project.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
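+#
+# This connector transfers KV caches and hidden states from prefill
+# (kv_producer) workers to decode (kv_consumer) workers over the Ascend
+# llm_datadist transport, using device IPs resolved via hccn_tool.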
+#
+import os
+import re
+import subprocess
+from typing import TYPE_CHECKING, List, Tuple, Union
+
+import torch
+import torch_npu
+import torchair  # type: ignore
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
+from vllm.logger import logger
+from vllm.sequence import IntermediateTensors
+
+import vllm_ascend.envs as envs
+
+if TYPE_CHECKING:
+    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata
+
+import llm_datadist  # type: ignore
+
+TORCH_DTYPE_TO_NPU_DTYPE = {
+    torch.half: llm_datadist.DataType.DT_FLOAT16,
+    torch.float16: llm_datadist.DataType.DT_FLOAT16,
+    torch.bfloat16: llm_datadist.DataType.DT_BF16,
+    torch.float: llm_datadist.DataType.DT_FLOAT,
+    torch.float32: llm_datadist.DataType.DT_FLOAT,
+    torch.int8: llm_datadist.DataType.DT_INT8,
+    torch.int64: llm_datadist.DataType.DT_INT64,
+    torch.int32: llm_datadist.DataType.DT_INT32
+}
+
+# Path to hccn_tool, which is used to query the device IPs.
+HCCN_TOOL_PATH = envs.HCCN_PATH
+
+
+def get_device_ips():
+    world_size = 8
+    npu_info = subprocess.run(['npu-smi', 'info', '-m'],
+                              stdout=subprocess.PIPE,
+                              stderr=subprocess.PIPE,
+                              universal_newlines=True)
+    if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH):
+        raise RuntimeError(
+            "npu-smi or hccn_tool is not available; both are required to "
+            "resolve NPU device IPs.")
+    npu_start_idx = int(
+        re.match(r'.*\n\t([0-9]+).*', npu_info.stdout).group(1))
+    device_ip_list = []
+    for ip_offset in range(world_size):
+        cmd = [
+            HCCN_TOOL_PATH, '-i', f'{npu_start_idx + ip_offset}', '-ip', '-g'
+        ]
+        device_ip_info = subprocess.run(cmd,
+                                        stdout=subprocess.PIPE,
+                                        stderr=subprocess.PIPE,
+                                        universal_newlines=True)
+        device_ip = re.match(r'ipaddr:(.*)\n', device_ip_info.stdout).group(1)
+        device_ip_list.append(device_ip)
+    return device_ip_list
+
+
+class KVTransferEngine:
+
+    def __init__(self, world_size, n_layer, role, local_rank):
+        self.world_size = world_size
+        self.n_layer = n_layer
+        self.role = role
+        self.device_ip_list = get_device_ips()
+        self.local_rank = local_rank
+        self.cluster_id = local_rank
+        self.data_dist = llm_datadist.LLMDataDist(self.role, self.cluster_id)
+
+        prompt_device_ids = envs.PROMPT_DEVICE_ID
+        decode_device_ids = envs.DECODE_DEVICE_ID
+        if prompt_device_ids is None or decode_device_ids is None:
+            raise ValueError(
+                "Please set both the PROMPT_DEVICE_ID and DECODE_DEVICE_ID "
+                "environment variables.")
+
+        prompt_ids = [
+            int(x.strip()) for x in prompt_device_ids.split(",") if x.strip()
+        ]
+        decode_ids = [
+            int(x.strip()) for x in decode_device_ids.split(",") if x.strip()
+        ]
+
+        self.prompt_ip_list = [self.device_ip_list[i] for i in prompt_ids]
+        self.decode_ip_list = [self.device_ip_list[i] for i in decode_ids]
+
+    def prepare_data_dist(self):
+        options = {
+            "llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME,
+        }
+        if self.role == llm_datadist.LLMRole.PROMPT:
+            options["ge.exec.deviceId"] = str(self.local_rank)
+            options[
+                "llm.listenIpInfo"] = f"{self.prompt_ip_list[self.local_rank]}:{envs.LLMDATADIST_COMM_PORT}"
+        else:
+            options["ge.exec.deviceId"] = str(self.local_rank)
+        self.data_dist.init(options)
+        self.kv_transfer = self.data_dist.kv_cache_manager
+        logger.info(
+            f"data dist is ready on rank {self.local_rank}/{self.world_size}")
+
+    def make_cluster(self, prefill_ip, cluster_id=-1):
+        cluster = llm_datadist.LLMClusterInfo()
+        cluster.remote_cluster_id = cluster_id
+        local_ip = self.decode_ip_list[self.local_rank]
+        remote_ip = prefill_ip
+        cluster.append_local_ip_info(local_ip, 0)
+        cluster.append_remote_ip_info(remote_ip, 26000)
+        return cluster
+
+
+class LLMDataDistConnector(KVConnectorBase):
+
+    def __init__(
+        self,
+        rank: int,
+        local_rank: int,
+        config: VllmConfig,
+    ):
+        self.config = config
+        self.tp_size = config.parallel_config.tensor_parallel_size
+        self.rank = rank
+        self.local_rank = local_rank
+
+        if self.config.kv_transfer_config.kv_role == "kv_producer":
+            self.role = llm_datadist.LLMRole.PROMPT
+        elif self.config.kv_transfer_config.kv_role == "kv_consumer":
+            self.role = llm_datadist.LLMRole.DECODER
+        else:
+            raise NotImplementedError(
+                "kv_role must be one of [kv_producer, kv_consumer]")
+
+        self.world_size = self.config.parallel_config.world_size
+        self.n_layer = self.config.model_config.get_num_layers(
+            self.config.parallel_config)
+
+        self.llm_datadist_engine = KVTransferEngine(self.world_size,
+                                                    self.n_layer, self.role,
+                                                    self.local_rank)
+        if self.role == llm_datadist.LLMRole.PROMPT:
+            self.llm_datadist_engine.prepare_data_dist()
+        else:
+            self.llm_datadist_engine.prepare_data_dist()
+            self.cluster = self.llm_datadist_engine.make_cluster(
+                self.llm_datadist_engine.prompt_ip_list[self.local_rank],
+                self.llm_datadist_engine.cluster_id)
+            _, ret = self.llm_datadist_engine.data_dist.link_clusters(
+                [self.cluster], 20000)
+            logger.info(
+                f"local_rank {self.local_rank} linked to prefill cluster, "
+                f"ret={ret}")
+
+    def send_kv_caches_and_hidden_states(
+        self, model_executable: torch.nn.Module,
+        model_input: "ModelInputForGPUWithSamplingMetadata",
+        kv_caches: List[torch.Tensor],
+        hidden_or_intermediate_states: Union[torch.Tensor,
+                                             IntermediateTensors]
+    ) -> None:
+        input_tokens_tensor = model_input.input_tokens
+        seq_lens = model_input.attn_metadata.seq_lens
+        slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
+        start_layer = model_executable.model.start_layer
+        end_layer = model_executable.model.end_layer
+
+        model_config = model_executable.model.config
+        num_heads = int(model_config.num_key_value_heads / self.tp_size)
+        hidden_size = model_config.hidden_size
+        num_attention_heads = model_config.num_attention_heads
+        head_size = int(hidden_size / num_attention_heads)
+
+        num_layer = end_layer - start_layer
+
+        # Shapes of the token, hidden-state and KV staging buffers
+        input_shape = (1, input_tokens_tensor.shape[0], 1, 1)
+        hidden_shape = (1, input_tokens_tensor.shape[0], 1, hidden_size)
+        kv_shape = (1, input_tokens_tensor.shape[0], num_heads, head_size)
+
+        assert kv_caches[0].dtype == hidden_or_intermediate_states.dtype
+        kv_hidden_dtype = kv_caches[0].dtype
+        input_dtype = torch.int32
+
+        # Describe the llm_datadist staging caches
+        key_desc = llm_datadist.CacheDesc(
+            num_layer,
+            kv_shape,
+            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
+            seq_len_dim_index=1)
+        value_desc = llm_datadist.CacheDesc(
+            num_layer,
+            kv_shape,
+            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
+            seq_len_dim_index=1)
+        input_desc = llm_datadist.CacheDesc(
+            1,
+            input_shape,
+            TORCH_DTYPE_TO_NPU_DTYPE[input_dtype],
+            seq_len_dim_index=-1)
+        hidden_desc = llm_datadist.CacheDesc(
+            1,
+            hidden_shape,
+            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
+            seq_len_dim_index=-1)
+
+        key_cache_keys = [
+            llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 1)
+        ]
+        value_cache_keys = [
+            llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 2)
+        ]
+        input_cache_keys = [
+            llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 3)
+        ]
+        hidden_cache_keys = [
+            llm_datadist.CacheKey(self.llm_datadist_engine.cluster_id, 0, 4)
+        ]
+
+        self.key_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
+            key_desc, key_cache_keys)
+        self.value_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
+            value_desc, value_cache_keys)
+        self.input_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
+            input_desc, input_cache_keys)
+        self.hidden_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
+            hidden_desc, hidden_cache_keys)
+
+        key_buffer_addr = self.key_buffer.per_device_tensor_addrs[0]
+        value_buffer_addr = self.value_buffer.per_device_tensor_addrs[0]
+        input_buffer_addr = self.input_buffer.per_device_tensor_addrs[0]
+        hidden_buffer_addr = self.hidden_buffer.per_device_tensor_addrs[0]
+
+        self.key_cache = torchair.llm_datadist.create_npu_tensors(
+            key_desc.shape, kv_hidden_dtype, key_buffer_addr)
+        self.value_cache = torchair.llm_datadist.create_npu_tensors(
+            value_desc.shape, kv_hidden_dtype, value_buffer_addr)
+        self.input_cache = torchair.llm_datadist.create_npu_tensors(
+            input_desc.shape, input_dtype, input_buffer_addr)
+        self.hidden_cache = torchair.llm_datadist.create_npu_tensors(
+            hidden_desc.shape, kv_hidden_dtype, hidden_buffer_addr)
+
+        indices = torch.tensor([0], dtype=torch.int64).npu()
+
+        # copy cache data into llm datadist cache using scatter update
+        for idx, slen in enumerate(seq_lens):
+            start_pos = sum(seq_lens[:idx])
+            end_pos = start_pos + slen
+            current_tokens = input_tokens_tensor[start_pos:end_pos].to(
+                torch.int32)
+
+            for layer_id in range(start_layer, end_layer):
+                kv_cache = kv_caches[layer_id - start_layer]
+
+                key_cache = kv_cache[0].view(-1, num_heads, head_size)
+                value_cache = kv_cache[1].view(-1, num_heads, head_size)
+
+                current_slot_mapping = slot_mapping_flat[start_pos:end_pos]
+
+                # copy key into datadist
+                k = self.key_cache[layer_id][:, start_pos:end_pos, :, :]
+                new_k = key_cache[current_slot_mapping].unsqueeze(0)
+                torch_npu.scatter_update_(k, indices, new_k, axis=-2)
+
+                # copy value into datadist
+                val = self.value_cache[layer_id][:, start_pos:end_pos, :, :]
+                new_val = value_cache[current_slot_mapping].unsqueeze(0)
+                torch_npu.scatter_update_(val, indices, new_val, axis=-2)
+
+            # copy input into datadist
+            inp = self.input_cache[0][:, start_pos:end_pos, :, :]
+            new_inp = current_tokens.view(1, current_tokens.shape[0], 1, 1)
+            torch_npu.scatter_update_(inp, indices, new_inp, axis=-2)
+
+            # copy hidden into datadist
+            hid = self.hidden_cache[0][:, start_pos:end_pos, :, :]
+            hid_shape0, hid_shape1 = hidden_or_intermediate_states[
+                start_pos:end_pos].shape
+            new_hid = hidden_or_intermediate_states[start_pos:end_pos].view(
+                1, hid_shape0, 1, hid_shape1)
+            torch_npu.scatter_update_(hid, indices, new_hid, axis=-2)
+
+        logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank())
+
+    def recv_kv_caches_and_hidden_states(
+        self, model_executable: torch.nn.Module,
+        model_input: "ModelInputForGPUWithSamplingMetadata",
+        kv_caches: List[torch.Tensor]
+    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
+               "ModelInputForGPUWithSamplingMetadata"]:
+        bypass_model_exec = True
+
+        input_tokens_tensor = model_input.input_tokens
+        seq_lens = model_input.attn_metadata.seq_lens
+        slot_mapping = model_input.attn_metadata.slot_mapping.flatten()
+
+        hidden_or_intermediate_states_for_one_req = []
+
+        input_tokens_list = []
+        num_computed_tokens_list = []
+        start_pos_list = []
+
+        # get model config
+        start_layer = model_executable.model.start_layer
+        end_layer = model_executable.model.end_layer
+        model_config = model_executable.model.config
+        num_heads = int(model_config.num_key_value_heads / self.tp_size)
+        hidden_size = model_config.hidden_size
+        num_attention_heads = model_config.num_attention_heads
+        head_size = int(hidden_size / num_attention_heads)
+        num_layer = end_layer - start_layer
+
+        # get input_tensor_shape and hidden_shape
+        input_shape = (1, input_tokens_tensor.shape[0], 1, 1)
+        hidden_shape = (1, input_tokens_tensor.shape[0], 1, hidden_size)
+        kv_shape = (1, input_tokens_tensor.shape[0], num_heads, head_size)
+
+        kv_hidden_dtype = kv_caches[0].dtype
+        input_dtype = torch.int32
+
+        # describe and allocate the decode-side llm_datadist staging caches
+        key_desc = llm_datadist.CacheDesc(
+            num_layer,
+            kv_shape,
+            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
+            seq_len_dim_index=-1)
+        value_desc = llm_datadist.CacheDesc(
+            num_layer,
+            kv_shape,
+            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
+            seq_len_dim_index=-1)
+        input_desc = llm_datadist.CacheDesc(
+            1,
+            input_shape,
+            TORCH_DTYPE_TO_NPU_DTYPE[input_dtype],
+            seq_len_dim_index=-1)
+        hidden_desc = llm_datadist.CacheDesc(
+            1,
+            hidden_shape,
+            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
+            seq_len_dim_index=-1)
+        self.decode_key_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
+            key_desc)
+        self.decode_value_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
+            value_desc)
+        self.decode_input_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
+            input_desc)
+        self.decode_hidden_buffer = self.llm_datadist_engine.kv_transfer.allocate_cache(
+            hidden_desc)
+        key_buffer_addrs = self.decode_key_buffer.per_device_tensor_addrs[0]
+        value_buffer_addrs = self.decode_value_buffer.per_device_tensor_addrs[
+            0]
+        input_buffer_addrs = self.decode_input_buffer.per_device_tensor_addrs[
+            0]
+        hidden_buffer_addrs = self.decode_hidden_buffer.per_device_tensor_addrs[
+            0]
+        self.key_cache = torchair.llm_datadist.create_npu_tensors(
+            key_desc.shape, kv_hidden_dtype, key_buffer_addrs)
+        self.value_cache = torchair.llm_datadist.create_npu_tensors(
+            value_desc.shape, kv_hidden_dtype, value_buffer_addrs)
+        self.input_cache = torchair.llm_datadist.create_npu_tensors(
+            input_desc.shape, input_dtype, input_buffer_addrs)
+        self.hidden_cache = torchair.llm_datadist.create_npu_tensors(
+            hidden_desc.shape, kv_hidden_dtype, hidden_buffer_addrs)
+
+        key_cache_key = llm_datadist.CacheKeyByIdAndIndex(
+            self.cluster.remote_cluster_id, 1, 0)
+        value_cache_key = llm_datadist.CacheKeyByIdAndIndex(
+            self.cluster.remote_cluster_id, 2, 0)
+        input_cache_key = llm_datadist.CacheKeyByIdAndIndex(
+            self.cluster.remote_cluster_id, 3, 0)
+        hidden_cache_key = llm_datadist.CacheKeyByIdAndIndex(
+            self.cluster.remote_cluster_id, 4, 0)
+
+        # pull the staged caches from the prefill node
+        self.llm_datadist_engine.kv_transfer.pull_cache(
+            key_cache_key, self.decode_key_buffer, 0)
+        self.llm_datadist_engine.kv_transfer.pull_cache(
+            value_cache_key, self.decode_value_buffer, 0)
+        self.llm_datadist_engine.kv_transfer.pull_cache(
+            input_cache_key, self.decode_input_buffer, 0)
+        self.llm_datadist_engine.kv_transfer.pull_cache(
+            hidden_cache_key, self.decode_hidden_buffer, 0)
+
+        keys = self.key_cache
+        values = self.value_cache
+        inputs = self.input_cache
+        hidden = self.hidden_cache
+
+        # enumerate the requests in this batch
+        for idx, slen in enumerate(seq_lens):
+            start_pos = sum(seq_lens[:idx])
+            end_pos = start_pos + slen
+            current_tokens = input_tokens_tensor[start_pos:end_pos]
+            num_tokens = slen
+
+            # collect data for rebuilding the model input
+            input_tokens_list.append(current_tokens)
+            start_pos_list.append(start_pos)
+
+            num_computed_tokens = inputs[0][0, start_pos:end_pos, 0,
+                                            0].shape[0]
+            num_computed_tokens_list.append(num_computed_tokens)
+
+            # Check whether both the KV cache and the hidden states were
+            # received. If not, the forward pass must be redone to compute
+            # the missing states.
+            if num_computed_tokens != num_tokens or hidden is None:
+                bypass_model_exec = False
+
+            # update the end position based on how many tokens are cached.
+            end_pos = start_pos + num_computed_tokens
+
+            # put received KV caches into paged memory
+            for i in range(model_executable.model.start_layer,
+                           model_executable.model.end_layer):
+                kv_cache = kv_caches[i - model_executable.model.start_layer]
+                key_cache, value_cache = kv_cache[0], kv_cache[1]
+
+                sliced_key = keys[i - model_executable.model.start_layer][
+                    0, start_pos:end_pos, :, :]
+                sliced_value = values[i - model_executable.model.start_layer][
+                    0, start_pos:end_pos, :, :]
+
+                torch_npu._npu_reshape_and_cache(
+                    key=sliced_key,
+                    value=sliced_value,
+                    key_cache=key_cache,
+                    value_cache=value_cache,
+                    slot_indices=slot_mapping[start_pos:end_pos])
+
+            hidden_or_intermediate_states_for_one_req.append(
+                hidden[0][0, start_pos:end_pos, 0, :])
+
+        if not bypass_model_exec:
+            # Some of the KV cache was not retrieved, so fall back to normal
+            # model forwarding. Optionally, model_input could be adjusted so
+            # that prefill is redone only for the tokens whose KV cache is
+            # missing.
+            logger.info(
+                "[rank%d][D]: Failed to receive all KVs and hidden "
+                "states, redo model forwarding.", torch.distributed.get_rank())
+            hidden_or_intermediate_states = None
+        else:
+            logger.info(
+                "[rank%d][D]: Successfully received all KVs and hidden "
+                "states, skip model forwarding.", torch.distributed.get_rank())
+            hidden_or_intermediate_states = torch.cat(
+                hidden_or_intermediate_states_for_one_req, dim=0)
+
+        return hidden_or_intermediate_states, bypass_model_exec, model_input
+
+    def close(self):
+        self.llm_datadist_engine.data_dist.unlink_clusters([self.cluster],
+                                                           5000)
\ No newline at end of file
diff --git a/vllm_ascend/envs.py b/vllm_ascend/envs.py
index fb88c3b..eb4f5f9 100644
--- a/vllm_ascend/envs.py
+++ b/vllm_ascend/envs.py
@@ -18,6 +18,17 @@ env_variables: Dict[str, Callable[[], Any]] = {
     lambda: os.getenv("ASCEND_HOME_PATH", None),
     "LD_LIBRARY_PATH":
     lambda: os.getenv("LD_LIBRARY_PATH", None),
+    # Used for disaggregated prefilling
+    "HCCN_PATH":
+    lambda: os.getenv("HCCN_PATH", "/usr/local/Ascend/driver/tools/hccn_tool"),
+    "PROMPT_DEVICE_ID":
+    lambda: os.getenv("PROMPT_DEVICE_ID", None),
+    "DECODE_DEVICE_ID":
+    lambda: os.getenv("DECODE_DEVICE_ID", None),
+    "LLMDATADIST_COMM_PORT":
+    lambda: os.getenv("LLMDATADIST_COMM_PORT", "26000"),
+    "LLMDATADIST_SYNC_CACHE_WAIT_TIME":
+    lambda: os.getenv("LLMDATADIST_SYNC_CACHE_WAIT_TIME", "5000")
 }
diff --git a/vllm_ascend/platform.py b/vllm_ascend/platform.py
index 7f26bfd..c434e19 100644
--- a/vllm_ascend/platform.py
+++ b/vllm_ascend/platform.py
@@ -150,7 +150,7 @@ class NPUPlatform(Platform):
 
     @classmethod
     def get_device_communicator_cls(cls) -> str:
-        return "vllm_ascend.communicator.NPUCommunicator"
+        return "vllm_ascend.distributed.communicator.NPUCommunicator"
 
     @classmethod
     def is_pin_memory_available(cls):
diff --git a/vllm_ascend/worker/worker.py b/vllm_ascend/worker/worker.py
index 99d89b2..306907e 100644
--- a/vllm_ascend/worker/worker.py
+++ b/vllm_ascend/worker/worker.py
@@ -24,8 +24,9 @@ import torch
 import torch.distributed
 from torch import nn
 from vllm import envs
-from vllm.config import ParallelConfig, VllmConfig
-from vllm.distributed import (ensure_model_parallel_initialized,
+from vllm.config import VllmConfig
+from vllm.distributed import (ensure_kv_transfer_initialized,
+                              ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
 from vllm.logger import logger
@@ -161,8 +162,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
         # Initialize the distributed environment.
-        self._init_worker_distributed_environment(self.parallel_config,
-                                                  self.rank,
+        self._init_worker_distributed_environment(self.vllm_config, self.rank,
                                                   self.distributed_init_method,
                                                   self.local_rank)
         # Set random seed.
@@ -450,12 +450,13 @@ class NPUWorker(LocalOrDistributedWorkerBase):
 
     def _init_worker_distributed_environment(
             self,
-            parallel_config: ParallelConfig,
+            vllm_config: VllmConfig,
             rank: int,
             distributed_init_method: Optional[str] = None,
             local_rank: int = -1,
             backend: str = "hccl") -> None:
         """Initialize the distributed environment."""
+        parallel_config = self.parallel_config
         set_custom_all_reduce(not parallel_config.disable_custom_all_reduce)
         init_distributed_environment(parallel_config.world_size, rank,
                                      distributed_init_method, local_rank,
@@ -463,6 +464,7 @@ class NPUWorker(LocalOrDistributedWorkerBase):
         ensure_model_parallel_initialized(
             parallel_config.tensor_parallel_size,
             parallel_config.pipeline_parallel_size)
+        ensure_kv_transfer_initialized(vllm_config)
 
 
 def raise_if_cache_size_invalid(num_gpu_blocks, block_size, is_attention_free,
diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py
index 6969640..ae50311 100644
--- a/vllm_ascend/worker/worker_v1.py
+++ b/vllm_ascend/worker/worker_v1.py
@@ -25,7 +25,8 @@ import torch.nn as nn
 import torch_npu
 from vllm import envs
 from vllm.config import VllmConfig
-from vllm.distributed import (ensure_model_parallel_initialized,
+from vllm.distributed import (ensure_kv_transfer_initialized,
+                              ensure_model_parallel_initialized,
                               init_distributed_environment,
                               set_custom_all_reduce)
 from vllm.logger import logger
@@ -197,6 +198,7 @@ class NPUWorker(WorkerBase):
         ensure_model_parallel_initialized(
             self.parallel_config.tensor_parallel_size,
             self.parallel_config.pipeline_parallel_size)
+        ensure_kv_transfer_initialized(self.vllm_config)
 
     def _init_profiler(self):
         # Torch profiler. Enabled and configured through env vars:
@@ -230,4 +232,4 @@ class NPUWorker(WorkerBase):
                 on_trace_ready=torch_npu.profiler.tensorboard_trace_handler(
                     torch_profiler_trace_dir))
         else:
-            return None
+            return None
\ No newline at end of file
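
The pieces above fit together as follows: vllm_ascend/distributed/__init__.py registers "AscendHcclConnector" with vLLM's KVConnectorFactory, the connector resolves NPU device IPs through hccn_tool using the PROMPT_DEVICE_ID and DECODE_DEVICE_ID variables added to envs.py, and the example script runs one prefill process and one decode process. Below is a minimal sketch of the prefill-side setup on a single host; the device-ID values and the prompt are illustrative assumptions (the example script expects PROMPT_DEVICE_ID and DECODE_DEVICE_ID to be exported separately), everything else mirrors the example in this diff.

import os

from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Assumed single-host layout: NPUs 0-1 do prefill, NPUs 2-3 do decode.
os.environ["PROMPT_DEVICE_ID"] = "0,1"
os.environ["DECODE_DEVICE_ID"] = "2,3"
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"  # this process is the prefill side

ktc = KVTransferConfig.from_cli(
    '{"kv_connector":"AscendHcclConnector","kv_buffer_device":"npu",'
    '"kv_role":"kv_producer","kv_parallel_size":2}')

llm = LLM(model="deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B",
          kv_transfer_config=ktc,
          tensor_parallel_size=2,
          max_model_len=2000,
          gpu_memory_utilization=0.8)

# max_tokens=1: the prefill node only needs to populate the KV cache; the
# decode node (kv_role="kv_consumer") pulls it and continues generation.
llm.generate(["Hello, how are you today?"],
             SamplingParams(temperature=0, top_p=0.95, max_tokens=1))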