#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""KV-cache transfer connector built on ``llm_datadist``.

A ``kv_producer`` (PROMPT role) copies KV caches, input tokens and hidden
states into datadist-managed buffers. A ``kv_consumer`` (DECODER role) pulls
those buffers from the linked prompt cluster and scatters them into its
paged KV cache.
"""

import os
import re
import subprocess
from typing import TYPE_CHECKING, List, Tuple, Union

import torch
import torch_npu
import torchair  # type: ignore
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.base import KVConnectorBase
from vllm.logger import logger
from vllm.sequence import IntermediateTensors

import vllm_ascend.envs as envs

if TYPE_CHECKING:
    from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata

import llm_datadist  # type: ignore

# Maps torch dtypes to the llm_datadist dtype enum used in CacheDesc.
TORCH_DTYPE_TO_NPU_DTYPE = {
    torch.half: llm_datadist.DataType.DT_FLOAT16,
    torch.float16: llm_datadist.DataType.DT_FLOAT16,
    torch.bfloat16: llm_datadist.DataType.DT_BF16,
    torch.float: llm_datadist.DataType.DT_FLOAT,
    torch.float32: llm_datadist.DataType.DT_FLOAT,
    torch.int8: llm_datadist.DataType.DT_INT8,
    torch.int64: llm_datadist.DataType.DT_INT64,
    torch.int32: llm_datadist.DataType.DT_INT32
}

# Get all device ips using hccn_tool
HCCN_TOOL_PATH = envs.HCCN_PATH


def get_device_ips(world_size: int = 8) -> List[str]:
    """Return the IP addresses of consecutive NPUs on this host.

    The first NPU index is parsed from ``npu-smi info -m``. The IP of each of
    ``world_size`` consecutive devices starting at that index is then read
    with ``hccn_tool -i <id> -ip -g``.

    Args:
        world_size: Number of consecutive devices to query. Defaults to 8,
            the value that used to be hard-coded here.

    Returns:
        A list of ``world_size`` IP strings, indexed by offset from the first
        NPU index.

    Raises:
        RuntimeError: If ``npu-smi`` or ``hccn_tool`` is unavailable, or if
            their output cannot be parsed.
    """
    try:
        npu_info = subprocess.run(['npu-smi', 'info', '-m'],
                                  stdout=subprocess.PIPE,
                                  stderr=subprocess.PIPE,
                                  universal_newlines=True)
    except FileNotFoundError as err:
        # npu-smi not on PATH: report it the same way as a failed invocation
        # instead of leaking a bare FileNotFoundError.
        raise RuntimeError(
            "No npu-smi/hccn_tool tools provided for NPU.") from err
    if npu_info.returncode != 0 or not os.path.exists(HCCN_TOOL_PATH):
        raise RuntimeError("No npu-smi/hccn_tool tools provided for NPU.")

    # The first number after "\n\t" in npu-smi's output is taken as the
    # starting NPU index.
    re_result = re.match(r'.*\n\t([0-9]+).*', npu_info.stdout)
    if re_result is None:
        raise RuntimeError("Can't find npu start index")
    npu_start_idx = int(re_result.group(1))

    device_ip_list = []
    for ip_offset in range(world_size):
        cmd = [
            HCCN_TOOL_PATH, '-i', f'{npu_start_idx + ip_offset}', '-ip', '-g'
        ]
        device_ip_info = subprocess.run(cmd,
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.PIPE,
                                        universal_newlines=True)
        re_result = re.match(r'ipaddr:(.*)\n', device_ip_info.stdout)
        if re_result is None:
            raise RuntimeError("Can't find npu ip")
        device_ip_list.append(re_result.group(1))
    return device_ip_list


class KVTransferEngine:
    """Owns the ``llm_datadist.LLMDataDist`` instance for one local rank.

    The prompt and decode device IPs are resolved from the comma-separated
    device indices in the ``PROMPT_DEVICE_ID`` and ``DECODE_DEVICE_ID``
    environment variables. Those indices are used to index the list returned
    by ``get_device_ips``.
    """

    def __init__(self, world_size, n_layer, role, local_rank):
        self.world_size = world_size
        self.n_layer = n_layer
        self.role = role
        self.device_ip_list = get_device_ips()
        self.local_rank = local_rank
        # The local rank doubles as this rank's datadist cluster id.
        self.cluster_id = local_rank
        self.data_dist = llm_datadist.LLMDataDist(self.role, self.cluster_id)

        prompt_device_ids = envs.PROMPT_DEVICE_ID
        decode_device_ids = envs.DECODE_DEVICE_ID
        if prompt_device_ids is None or decode_device_ids is None:
            raise ValueError(
                "Please specify env PROMPT_DEVICE_ID or DECODE_DEVICE_ID")

        # Blank entries (e.g. a trailing comma) are skipped.
        prompt_ids = [
            int(x.strip()) for x in prompt_device_ids.split(",") if x.strip()
        ]
        decode_ids = [
            int(x.strip()) for x in decode_device_ids.split(",") if x.strip()
        ]

        self.prompt_ip_list = [self.device_ip_list[i] for i in prompt_ids]
        self.decode_ip_list = [self.device_ip_list[i] for i in decode_ids]

    def prepare_data_dist(self):
        """Initialise the datadist instance and expose its kv cache manager.

        On the PROMPT side this rank also listens on
        ``<prompt ip>:LLMDATADIST_COMM_PORT`` so decode ranks can link to it.
        """
        options = {
            "llm.SyncKvCacheWaitTime": envs.LLMDATADIST_SYNC_CACHE_WAIT_TIME,
        }
        options["ge.exec.deviceId"] = str(self.local_rank)
        if self.role == llm_datadist.LLMRole.PROMPT:
            options[
                "llm.listenIpInfo"] = f"{self.prompt_ip_list[self.local_rank]}:{envs.LLMDATADIST_COMM_PORT}"
        self.data_dist.init(options)
        self.kv_transfer = self.data_dist.kv_cache_manager
        logger.info("%s/%s rank data dist is ready", self.local_rank,
                    self.world_size)

    def make_cluster(self, prefill_ip, cluster_id=-1):
        """Describe a link from this decode rank to the prefill rank.

        Args:
            prefill_ip: IP of the remote (prefill) device.
            cluster_id: Datadist cluster id of the remote side.

        Returns:
            An ``llm_datadist.LLMClusterInfo`` ready for ``link_clusters``.
        """
        cluster = llm_datadist.LLMClusterInfo()
        cluster.remote_cluster_id = cluster_id
        local_ip = self.decode_ip_list[self.local_rank]
        remote_ip = prefill_ip
        # NOTE(review): local port 0 presumably means "any/unused" -- confirm.
        cluster.append_local_ip_info(local_ip, 0)
        # Dial the same port the PROMPT side listens on in
        # prepare_data_dist (previously hard-coded to 26000).
        cluster.append_remote_ip_info(remote_ip,
                                      int(envs.LLMDATADIST_COMM_PORT))
        return cluster


class LLMDataDistConnector(KVConnectorBase):
    """vLLM KV connector that transfers KV caches via ``llm_datadist``.

    ``kv_producer`` maps to ``LLMRole.PROMPT`` and ``kv_consumer`` maps to
    ``LLMRole.DECODER``. A consumer links to the prompt rank with the same
    local rank at construction time.
    """

    def __init__(
        self,
        rank: int,
        local_rank: int,
        config: VllmConfig,
    ):
        self.config = config
        self.tp_size = config.parallel_config.tensor_parallel_size
        self.rank = rank
        self.local_rank = local_rank
        # Only the DECODER side links a cluster; None means nothing to unlink.
        self.cluster = None

        kv_role = self.config.kv_transfer_config.kv_role
        if kv_role == "kv_producer":
            self.role = llm_datadist.LLMRole.PROMPT
        elif kv_role == "kv_consumer":
            self.role = llm_datadist.LLMRole.DECODER
        else:
            raise NotImplementedError(
                "kv_role should be inside [kv_producer, kv_consumer]")

        self.world_size = self.config.parallel_config.world_size
        self.n_layer = self.config.model_config.get_num_layers(
            self.config.parallel_config)

        self.llm_datadist_engine = KVTransferEngine(self.world_size,
                                                    self.n_layer, self.role,
                                                    self.local_rank)
        self.llm_datadist_engine.prepare_data_dist()
        if self.role != llm_datadist.LLMRole.PROMPT:
            self.cluster = self.llm_datadist_engine.make_cluster(
                self.llm_datadist_engine.prompt_ip_list[self.local_rank],
                self.llm_datadist_engine.cluster_id)
            # NOTE(review): 20000 is presumably a link timeout in ms -- confirm
            # against the llm_datadist API.
            _, ret = self.llm_datadist_engine.data_dist.link_clusters(
                [self.cluster], 20000)
            logger.info("local_rank %s link, ret=%s", self.local_rank, ret)

    def _get_model_dims(
            self,
            model_executable: torch.nn.Module) -> Tuple[int, int, int, int, int]:
        """Return ``(start_layer, end_layer, num_kv_heads, hidden_size, head_size)``.

        ``num_kv_heads`` is the per-TP-rank KV head count, clamped to at least
        1 (as vLLM does) so models with fewer KV heads than TP ranks do not
        yield 0. ``head_size`` is ``hidden_size // num_attention_heads``.
        """
        model = model_executable.model
        model_config = model.config
        num_kv_heads = max(1, model_config.num_key_value_heads // self.tp_size)
        head_size = model_config.hidden_size // model_config.num_attention_heads
        return (model.start_layer, model.end_layer, num_kv_heads,
                model_config.hidden_size, head_size)

    def send_kv_caches_and_hidden_states(
        self,
        model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor],
        hidden_or_intermediate_states: Union[torch.Tensor,
                                             IntermediateTensors],
    ) -> None:
        """Copy this batch's KV, input tokens and hidden states into datadist.

        Four caches are allocated and registered under
        ``CacheKey(cluster_id, 0, 1..4)`` (key, value, input, hidden). The
        decode side pulls them by those indices. Each request's slice
        ``[start_pos, end_pos)`` of the flattened batch is scatter-updated
        into the matching slice of the datadist tensors.
        """
        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        slot_mapping_flat = model_input.attn_metadata.slot_mapping.flatten()
        (start_layer, end_layer, num_heads, hidden_size,
         head_size) = self._get_model_dims(model_executable)
        num_layer = end_layer - start_layer

        # Shapes of the datadist tensors, covering all tokens in the batch.
        num_tokens = input_tokens_tensor.shape[0]
        input_shape = (1, num_tokens, 1, 1)
        hidden_shape = (1, num_tokens, 1, hidden_size)
        kv_shape = (1, num_tokens, num_heads, head_size)

        assert kv_caches[0].dtype == hidden_or_intermediate_states.dtype
        kv_hidden_dtype = kv_caches[0].dtype
        input_dtype = torch.int32

        # initialize LLMDatadist data structure
        key_desc = llm_datadist.CacheDesc(
            num_layer,
            kv_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
            seq_len_dim_index=1)
        value_desc = llm_datadist.CacheDesc(
            num_layer,
            kv_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
            seq_len_dim_index=1)
        input_desc = llm_datadist.CacheDesc(
            1,
            input_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[input_dtype],
            seq_len_dim_index=-1)
        hidden_desc = llm_datadist.CacheDesc(
            1,
            hidden_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
            seq_len_dim_index=-1)

        cluster_id = self.llm_datadist_engine.cluster_id
        key_cache_keys = [llm_datadist.CacheKey(cluster_id, 0, 1)]
        value_cache_keys = [llm_datadist.CacheKey(cluster_id, 0, 2)]
        input_cache_keys = [llm_datadist.CacheKey(cluster_id, 0, 3)]
        hidden_cache_keys = [llm_datadist.CacheKey(cluster_id, 0, 4)]

        # NOTE(review): caches are allocated on every call and the previous
        # buffers are not explicitly deallocated here -- confirm lifetime.
        kv_transfer = self.llm_datadist_engine.kv_transfer
        self.key_buffer = kv_transfer.allocate_cache(key_desc, key_cache_keys)
        self.value_buffer = kv_transfer.allocate_cache(value_desc,
                                                       value_cache_keys)
        self.input_buffer = kv_transfer.allocate_cache(input_desc,
                                                       input_cache_keys)
        self.hidden_buffer = kv_transfer.allocate_cache(
            hidden_desc, hidden_cache_keys)

        # Wrap the device-0 buffer addresses as torch tensors (one per layer).
        self.key_cache = torchair.llm_datadist.create_npu_tensors(
            key_desc.shape, kv_hidden_dtype,
            self.key_buffer.per_device_tensor_addrs[0])
        self.value_cache = torchair.llm_datadist.create_npu_tensors(
            value_desc.shape, kv_hidden_dtype,
            self.value_buffer.per_device_tensor_addrs[0])
        self.input_cache = torchair.llm_datadist.create_npu_tensors(
            input_desc.shape, input_dtype,
            self.input_buffer.per_device_tensor_addrs[0])
        self.hidden_cache = torchair.llm_datadist.create_npu_tensors(
            hidden_desc.shape, kv_hidden_dtype,
            self.hidden_buffer.per_device_tensor_addrs[0])

        indices = torch.tensor([0], dtype=torch.int64).npu()

        # copy cache data into llm datadist cache using scatter update
        start_pos = 0
        for slen in seq_lens:
            end_pos = start_pos + slen
            current_slot_mapping = slot_mapping_flat[start_pos:end_pos]

            for layer_idx in range(num_layer):
                # kv_caches and the datadist tensors both hold only this
                # stage's layers, so index relative to start_layer.
                kv_cache = kv_caches[layer_idx]
                key_cache = kv_cache[0].view(-1, num_heads, head_size)
                value_cache = kv_cache[1].view(-1, num_heads, head_size)

                # copy key into datadist
                k = self.key_cache[layer_idx][:, start_pos:end_pos, :, :]
                new_k = key_cache[current_slot_mapping].unsqueeze(0)
                torch_npu.scatter_update_(k, indices, new_k, axis=-2)

                # copy value into datadist
                val = self.value_cache[layer_idx][:, start_pos:end_pos, :, :]
                new_val = value_cache[current_slot_mapping].unsqueeze(0)
                torch_npu.scatter_update_(val, indices, new_val, axis=-2)

            # Input tokens and hidden states are per request, not per layer,
            # so copy them once.
            current_tokens = input_tokens_tensor[start_pos:end_pos].to(
                torch.int32)
            inp = self.input_cache[0][:, start_pos:end_pos, :, :]
            new_inp = current_tokens.view(1, current_tokens.shape[0], 1, 1)
            torch_npu.scatter_update_(inp, indices, new_inp, axis=-2)

            req_hidden = hidden_or_intermediate_states[start_pos:end_pos]
            hid_shape0, hid_shape1 = req_hidden.shape
            hid = self.hidden_cache[0][:, start_pos:end_pos, :, :]
            new_hid = req_hidden.view(1, hid_shape0, 1, hid_shape1)
            torch_npu.scatter_update_(hid, indices, new_hid, axis=-2)

            start_pos = end_pos

        logger.info("[rank%d][P]: KV send DONE.", torch.distributed.get_rank())

    def recv_kv_caches_and_hidden_states(
        self, model_executable: torch.nn.Module,
        model_input: "ModelInputForGPUWithSamplingMetadata",
        kv_caches: List[torch.Tensor]
    ) -> Tuple[Union[torch.Tensor, IntermediateTensors], bool,
               "ModelInputForGPUWithSamplingMetadata"]:
        """Pull KV, inputs and hidden states from the linked prompt cluster.

        The pulled KV slices are written into the local paged ``kv_caches``
        via ``_npu_reshape_and_cache`` using the batch slot mapping.

        Returns:
            ``(hidden_states, bypass_model_exec, model_input)``. When every
            request was fully received, ``hidden_states`` is the
            concatenation of the per-request hidden states and
            ``bypass_model_exec`` is True. Otherwise ``hidden_states`` is None
            and the caller must run the forward pass.
        """
        bypass_model_exec = True
        input_tokens_tensor = model_input.input_tokens
        seq_lens = model_input.attn_metadata.seq_lens
        slot_mapping = model_input.attn_metadata.slot_mapping.flatten()

        hidden_or_intermediate_states_for_one_req = []

        (start_layer, end_layer, num_heads, hidden_size,
         head_size) = self._get_model_dims(model_executable)
        num_layer = end_layer - start_layer

        # get input_tensor_shape and hidden_shape
        num_tokens = input_tokens_tensor.shape[0]
        input_shape = (1, num_tokens, 1, 1)
        hidden_shape = (1, num_tokens, 1, hidden_size)
        kv_shape = (1, num_tokens, num_heads, head_size)

        kv_hidden_dtype = kv_caches[0].dtype
        input_dtype = torch.int32

        # Add LLM DataDist initialization
        key_desc = llm_datadist.CacheDesc(
            num_layer,
            kv_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
            seq_len_dim_index=-1)
        value_desc = llm_datadist.CacheDesc(
            num_layer,
            kv_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
            seq_len_dim_index=-1)
        input_desc = llm_datadist.CacheDesc(
            1,
            input_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[input_dtype],
            seq_len_dim_index=-1)
        hidden_desc = llm_datadist.CacheDesc(
            1,
            hidden_shape,
            TORCH_DTYPE_TO_NPU_DTYPE[kv_hidden_dtype],
            seq_len_dim_index=-1)

        # NOTE(review): buffers are re-allocated per call and the previous
        # ones are not explicitly deallocated here -- confirm lifetime.
        kv_transfer = self.llm_datadist_engine.kv_transfer
        self.decode_key_buffer = kv_transfer.allocate_cache(key_desc)
        self.decode_value_buffer = kv_transfer.allocate_cache(value_desc)
        self.decode_input_buffer = kv_transfer.allocate_cache(input_desc)
        self.decode_hidden_buffer = kv_transfer.allocate_cache(hidden_desc)

        self.key_cache = torchair.llm_datadist.create_npu_tensors(
            key_desc.shape, kv_hidden_dtype,
            self.decode_key_buffer.per_device_tensor_addrs[0])
        self.value_cache = torchair.llm_datadist.create_npu_tensors(
            value_desc.shape, kv_hidden_dtype,
            self.decode_value_buffer.per_device_tensor_addrs[0])
        self.input_cache = torchair.llm_datadist.create_npu_tensors(
            input_desc.shape, input_dtype,
            self.decode_input_buffer.per_device_tensor_addrs[0])
        self.hidden_cache = torchair.llm_datadist.create_npu_tensors(
            hidden_desc.shape, kv_hidden_dtype,
            self.decode_hidden_buffer.per_device_tensor_addrs[0])

        # Indices 1..4 match the CacheKey ids used by the send side.
        remote_cluster_id = self.cluster.remote_cluster_id
        key_cache_key = llm_datadist.CacheKeyByIdAndIndex(
            remote_cluster_id, 1, 0)
        value_cache_key = llm_datadist.CacheKeyByIdAndIndex(
            remote_cluster_id, 2, 0)
        input_cache_key = llm_datadist.CacheKeyByIdAndIndex(
            remote_cluster_id, 3, 0)
        hidden_cache_key = llm_datadist.CacheKeyByIdAndIndex(
            remote_cluster_id, 4, 0)

        kv_transfer.pull_cache(key_cache_key, self.decode_key_buffer, 0)
        kv_transfer.pull_cache(value_cache_key, self.decode_value_buffer, 0)
        kv_transfer.pull_cache(input_cache_key, self.decode_input_buffer, 0)
        kv_transfer.pull_cache(hidden_cache_key, self.decode_hidden_buffer,
                               0)

        keys = self.key_cache
        values = self.value_cache
        inputs = self.input_cache
        hidden = self.hidden_cache

        # enumerate different requests
        start_pos = 0
        for slen in seq_lens:
            end_pos = start_pos + slen
            num_computed_tokens = inputs[0][0, start_pos:end_pos, 0,
                                            0].shape[0]

            # check if both KV cache and the hidden states are received
            # If not, need to redo the forwarding to compute missing states
            if num_computed_tokens != slen or hidden is None:
                bypass_model_exec = False

            # update the end position based on how many tokens are cached.
            end_pos = start_pos + num_computed_tokens

            # put received KV caches into paged memory
            for layer_idx in range(num_layer):
                key_cache, value_cache = kv_caches[layer_idx][0], kv_caches[
                    layer_idx][1]
                sliced_key = keys[layer_idx][0, start_pos:end_pos, :, :]
                sliced_value = values[layer_idx][0, start_pos:end_pos, :, :]
                torch_npu._npu_reshape_and_cache(
                    key=sliced_key,
                    value=sliced_value,
                    key_cache=key_cache,
                    value_cache=value_cache,
                    slot_indices=slot_mapping[start_pos:end_pos])

            hidden_or_intermediate_states_for_one_req.append(
                hidden[0][0, start_pos:end_pos, 0, :])

            # The next request starts after this request's full length,
            # independent of how many tokens were actually received.
            start_pos += slen

        if not bypass_model_exec:
            # Some of the KV cache is not retrieved
            # Here we will fall back to normal model forwarding
            # But optionally you can adjust model_input so that you only do
            # prefilling on those tokens that are missing KV caches.
            logger.info(
                "[rank%d][D]: Failed to receive all KVs and hidden "
                "states, redo model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = None
        else:
            logger.info(
                "[rank%d][D]: Successfully received all KVs and hidden "
                "states, skip model forwarding.", torch.distributed.get_rank())
            hidden_or_intermediate_states = torch.cat(
                hidden_or_intermediate_states_for_one_req, dim=0)

        return hidden_or_intermediate_states, bypass_model_exec, model_input

    def close(self):
        """Unlink the prompt cluster if this rank linked one (DECODER only)."""
        if self.cluster is not None:
            self.llm_datadist_engine.data_dist.unlink_clusters([self.cluster],
                                                               5000)