diff --git a/csrc/torch_binding.cpp b/csrc/torch_binding.cpp
index 9eaba723..90e7f03a 100644
--- a/csrc/torch_binding.cpp
+++ b/csrc/torch_binding.cpp
@@ -17,15 +17,77 @@
 #include
 #include
 #include
+#include
 #include
 #include
+#include "torch_npu/csrc/core/npu/NPUGuard.h"
 #include
 #include "acl/acl.h"
+#include "acl/acl_rt.h"
 #include "ops.h"
 #include "utils.h"
 #include "mla_preprocess/op_host/mla_preprocess.h"
+#include
+#include
+#include
+
 namespace vllm_ascend {
+
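+// Copy KV-cache blocks between host and device memory. block_mapping is a
+// CPU tensor of shape [num_blocks, 2] holding (src_block, dst_block) index
+// pairs; each pair is copied asynchronously on the given stream.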
+void swap_blocks_impl(torch::Tensor& src, torch::Tensor& dst,
+                      const torch::Tensor& block_mapping, aclrtStream stream) {
+    torch::Device src_device = src.device();
+    torch::Device dst_device = dst.device();
+    aclrtMemcpyKind memcpy_type;
+
+    if ((!src_device.is_cpu()) && (!dst_device.is_cpu())) {
+        TORCH_CHECK(src_device.index() == dst_device.index(),
+                    "src and dst must be on the same npu");
+        memcpy_type = ACL_MEMCPY_DEVICE_TO_DEVICE;
+    } else if ((!src_device.is_cpu()) && dst_device.is_cpu()) {
+        memcpy_type = ACL_MEMCPY_DEVICE_TO_HOST;
+    } else if (src_device.is_cpu() && (!dst_device.is_cpu())) {
+        memcpy_type = ACL_MEMCPY_HOST_TO_DEVICE;
+    } else {
+        TORCH_CHECK(false, "Invalid device combination, src tensor device: ",
+                    src_device, ", dst tensor device: ", dst_device);
+    }
+
+    TORCH_CHECK(block_mapping.device().is_cpu(), "block_mapping must be on CPU");
+
+    char* src_ptr = static_cast<char*>(src.data_ptr());
+    char* dst_ptr = static_cast<char*>(dst.data_ptr());
+
+    const int64_t block_size_in_bytes = src.element_size() * src.stride(0);
+
+    const int64_t num_blocks = block_mapping.size(0);
+    const int64_t max_src_block = src.size(0);
+    const int64_t max_dst_block = dst.size(0);
+    for (int64_t i = 0; i < num_blocks; i++) {
+        int64_t src_block_number = block_mapping[i][0].item<int64_t>();
+        int64_t dst_block_number = block_mapping[i][1].item<int64_t>();
+        TORCH_CHECK(src_block_number >= 0 && src_block_number < max_src_block,
+                    "src block index ", src_block_number,
+                    " out of range (num blocks: ", max_src_block, ")");
+        TORCH_CHECK(dst_block_number >= 0 && dst_block_number < max_dst_block,
+                    "dst block index ", dst_block_number,
+                    " out of range (num blocks: ", max_dst_block, ")");
+
+        int64_t src_offset = src_block_number * block_size_in_bytes;
+        int64_t dst_offset = dst_block_number * block_size_in_bytes;
+
+        aclError ret = aclrtMemcpyAsync(dst_ptr + dst_offset, block_size_in_bytes,
+                                        src_ptr + src_offset, block_size_in_bytes,
+                                        memcpy_type, stream);
+        TORCH_CHECK(ret == ACL_SUCCESS, "aclrtMemcpyAsync failed, error code: ", ret);
+    }
+}
+
+void swap_blocks(torch::Tensor& src, torch::Tensor& dst,
+                 const torch::Tensor& block_mapping)
+{
+    const c10_npu::OptionalNPUGuard npu_guard(
+        (!src.device().is_cpu()) ? src.device() : dst.device());
+    aclrtStream stream = c10_npu::getCurrentNPUStream().stream();
+    swap_blocks_impl(src, dst, block_mapping, stream);
+}
 
 AscendType get_dtype_from_torch(at::ScalarType scalarType)
 {
@@ -511,4 +573,7 @@ TORCH_LIBRARY_EXPAND(CONCAT(_C, _ascend), ops)
         " Tensor q_out1, Tensor kv_cache_out1)"
     );
     ops.impl("mla_preprocess", torch::kPrivateUse1, &vllm_ascend::mla_preprocess);
+
+    ops.def("swap_blocks(Tensor! src, Tensor! dst, Tensor block_mapping) -> ()");
+    ops.impl("swap_blocks", torch::kPrivateUse1, &vllm_ascend::swap_blocks);
 }
diff --git a/vllm_ascend/kv_offload/__init__.py b/vllm_ascend/kv_offload/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/vllm_ascend/kv_offload/cpu_npu.py b/vllm_ascend/kv_offload/cpu_npu.py
new file mode 100644
index 00000000..8924ebcf
--- /dev/null
+++ b/vllm_ascend/kv_offload/cpu_npu.py
@@ -0,0 +1,168 @@
+import numpy as np
+import torch
+from vllm.attention import AttentionBackend
+from vllm.logger import init_logger
+from vllm.utils import is_pin_memory_available
+from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
+from vllm.v1.kv_offload.worker.worker import (OffloadingHandler,
+                                              TransferResult, TransferSpec)
+
+logger = init_logger(__name__)
+
+
+def expand_block_ids(
+    block_ids: np.ndarray,
+    block_size_factor: int,
+    output: np.ndarray,
+    skip_count: int = 0,
+):
+    """
+    Expand a list of block IDs into the IDs of their constituent sub-blocks,
+    assuming each block is composed of block_size_factor sub-blocks.
+    Writes the result into the output array.
+    The first skip_count sub-blocks of the first block are skipped.
+    Note that skip_count must be less than block_size_factor.
+
+    For example, if block_ids = [0, 1, 3] and block_size_factor = 4,
+    then it yields [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]
+    since 0 maps to [0, 1, 2, 3],
+    1 maps to [4, 5, 6, 7],
+    and 3 maps to [12, 13, 14, 15].
+    """
+    assert skip_count < block_size_factor
+
+    first_range = np.arange(skip_count, block_size_factor)
+    full_range = np.arange(0, block_size_factor)
+
+    output_idx = 0
+    for i, block_id in enumerate(block_ids):
+        base_block_id = block_id * block_size_factor
+        indices = first_range if i == 0 else full_range
+        output_end_idx = output_idx + len(indices)
+        output[output_idx:output_end_idx] = base_block_id + indices
+        output_idx = output_end_idx
+
+
+class CpuNpuOffloadingHandler(OffloadingHandler):
+
+    def __init__(
+        self,
+        gpu_block_size: int,
+        cpu_block_size: int,
+        num_cpu_blocks: int,
+        gpu_caches: dict[str, torch.Tensor],
+        attn_backends: dict[str, type[AttentionBackend]],
+    ):
+        assert cpu_block_size % gpu_block_size == 0
+        self.block_size_factor = cpu_block_size // gpu_block_size
+
+        # npu streams for npu->cpu and cpu->npu
+        self.d2h_stream = torch.npu.Stream()
+        self.h2d_stream = torch.npu.Stream()
+
+        # job_id -> transfer npu event
+        self.transfer_events: dict[int, torch.npu.Event] = {}
+        # list of npu events available for reuse
+        self.events_pool: list[torch.npu.Event] = []
+
+        pin_memory = is_pin_memory_available()
+
+        # allocate cpu tensors
+        logger.info("Allocating %d CPU tensors...", len(gpu_caches))
+        self.npu_tensors: list[torch.Tensor] = []
+        self.cpu_tensors: list[tuple[torch.Tensor, torch.Tensor]] = []
+        for layer_name, gpu_tensor in gpu_caches.items():
+            self.npu_tensors.append(gpu_tensor)
+
+            gpu_shape = gpu_tensor[0].shape
+
+            num_blocks_idx = 0
+            cpu_shape = list(gpu_shape)
+            cpu_shape[num_blocks_idx] = num_cpu_blocks * self.block_size_factor
+
+            logger.debug("Allocating CPU tensor of shape %r", cpu_shape)
+            self.cpu_tensors.append((
+                torch.zeros(
+                    cpu_shape,
+                    dtype=gpu_tensor[0].dtype,
+                    device="cpu",
+                    pin_memory=pin_memory,
+                ),
+                torch.zeros(
+                    cpu_shape,
+                    dtype=gpu_tensor[0].dtype,
+                    device="cpu",
+                    pin_memory=pin_memory,
+                ),
+            ))
+
+    def transfer_async(self, job_id: int, spec: TransferSpec) -> bool:
+        logger.info("start transfer_async...")
+        src_spec, dst_spec = spec
+        if isinstance(src_spec, CPULoadStoreSpec):
+            assert isinstance(dst_spec, GPULoadStoreSpec)
+            stream = self.h2d_stream
+            src_tensors = self.cpu_tensors
+            dst_tensors = self.npu_tensors
+            src_block_size_factor = self.block_size_factor
+            dst_block_size_factor = 1
+        else:
+            assert isinstance(src_spec, GPULoadStoreSpec)
+            assert isinstance(dst_spec, CPULoadStoreSpec)
+            stream = self.d2h_stream
+            src_tensors = self.npu_tensors
+            dst_tensors = self.cpu_tensors
+            src_block_size_factor = 1
+            dst_block_size_factor = self.block_size_factor
+
+        src_blocks = src_spec.block_ids
+        dst_blocks = dst_spec.block_ids
+        assert src_blocks.ndim == 1
+        assert dst_blocks.ndim == 1
+
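+        # A CPU block holds block_size_factor NPU blocks. If the source blocks
+        # do not fill the destination blocks exactly, the leading sub-blocks of
+        # the first destination block are skipped so that the copied data ends
+        # flush with the last destination block.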
+        dst_sub_blocks_to_skip = -src_blocks.size % dst_block_size_factor
+        src_sub_block_count = src_blocks.size * src_block_size_factor
+
+        assert (src_sub_block_count == dst_blocks.size * dst_block_size_factor -
+                dst_sub_blocks_to_skip)
+
+        src_to_dst = np.empty((src_sub_block_count, 2), dtype=np.int64)
+        expand_block_ids(src_blocks, src_block_size_factor, src_to_dst[:, 0])
+        expand_block_ids(
+            dst_blocks,
+            dst_block_size_factor,
+            src_to_dst[:, 1],
+            skip_count=dst_sub_blocks_to_skip,
+        )
+        src_to_dst_tensor = torch.from_numpy(src_to_dst)
+
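+        # Issue the copies on the dedicated NPU stream and track completion
+        # with an event; finished events are returned to the pool for reuse.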
+        event = self.events_pool.pop() if self.events_pool else torch.npu.Event()
+        with torch.npu.stream(stream):
+            for src_tensor, dst_tensor in zip(src_tensors, dst_tensors):
+                src_key_cache, src_value_cache = src_tensor[0], src_tensor[1]
+                dst_key_cache, dst_value_cache = dst_tensor[0], dst_tensor[1]
+
+                torch.ops._C_ascend.swap_blocks(src_key_cache, dst_key_cache,
+                                                src_to_dst_tensor)
+                torch.ops._C_ascend.swap_blocks(src_value_cache,
+                                                dst_value_cache,
+                                                src_to_dst_tensor)
+
+            event.record(stream)
+
+        self.transfer_events[job_id] = event
+
+        # success
+        return True
+
+    def get_finished(self) -> list[TransferResult]:
+        results: list[TransferResult] = []
+        for job_id, event in self.transfer_events.items():
+            if event.query():
+                results.append((job_id, True))
+                self.events_pool.append(event)
+        for job_id, _ in results:
+            del self.transfer_events[job_id]
+        return results
diff --git a/vllm_ascend/kv_offload/npu.py b/vllm_ascend/kv_offload/npu.py
new file mode 100644
index 00000000..9f80237b
--- /dev/null
+++ b/vllm_ascend/kv_offload/npu.py
@@ -0,0 +1,71 @@
+from collections.abc import Iterator
+from typing import Optional
+
+import torch
+from vllm.config import VllmConfig, get_layers_from_vllm_config
+from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
+from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
+from vllm.v1.kv_offload.backends.cpu import CPUBackend
+from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
+from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
+from vllm.v1.kv_offload.spec import OffloadingSpec
+from vllm.v1.kv_offload.worker.worker import OffloadingHandler
+
+from vllm_ascend.kv_offload.cpu_npu import CpuNpuOffloadingHandler
+
+
+class NPUOffloadingSpec(OffloadingSpec):
+
+    def __init__(self, vllm_config: VllmConfig):
+        super().__init__(vllm_config)
+
+        num_cpu_blocks = self.extra_config.get("num_cpu_blocks")
+        if not num_cpu_blocks:
+            raise ValueError(
+                "num_cpu_blocks must be specified in kv_connector_extra_config"
+            )
+        self.num_cpu_blocks: int = num_cpu_blocks
+
+        # scheduler-side
+        self._manager: Optional[OffloadingManager] = None
+
+        # worker-side
+        self._handler: Optional[OffloadingHandler] = None
+
+    def get_manager(self) -> OffloadingManager:
+        if not self._manager:
+            kv_events_config = self.vllm_config.kv_events_config
+            enable_events = (kv_events_config is not None
+                             and kv_events_config.enable_kv_cache_events)
+            self._manager = LRUOffloadingManager(
+                CPUBackend(block_size=self.offloaded_block_size,
+                           num_blocks=self.num_cpu_blocks),
+                enable_events=enable_events,
+            )
+        return self._manager
+
+    def get_handlers(
+        self, kv_caches: dict[str, torch.Tensor]
+    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
+                        OffloadingHandler]]:
+        if not self._handler:
+            layer_names = list(kv_caches.keys())
+            layers = get_layers_from_vllm_config(self.vllm_config,
+                                                 AttentionLayerBase,
+                                                 layer_names)
+            attn_backends = {
+                layer_name: layers[layer_name].get_attn_backend()
+                for layer_name in layer_names
+            }
+
+            self._handler = CpuNpuOffloadingHandler(
+                attn_backends=attn_backends,
+                gpu_block_size=self.gpu_block_size,
+                cpu_block_size=self.offloaded_block_size,
+                num_cpu_blocks=self.num_cpu_blocks,
+                gpu_caches=kv_caches,
+            )
+
+        assert self._handler is not None
+        yield GPULoadStoreSpec, CPULoadStoreSpec, self._handler
+        yield CPULoadStoreSpec, GPULoadStoreSpec, self._handler
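A quick way to sanity-check the sub-block expansion, as a standalone sketch
that uses only names introduced in this patch and mirrors the docstring
example of expand_block_ids:

    import numpy as np
    from vllm_ascend.kv_offload.cpu_npu import expand_block_ids

    out = np.empty(12, dtype=np.int64)
    expand_block_ids(np.array([0, 1, 3]), block_size_factor=4, output=out)
    assert out.tolist() == [0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15]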