### What this PR does / why we need it?

This PR addresses https://github.com/vllm-project/vllm-ascend/issues/3241 : an in-house solution for offloading KV cache data from GPU memory to other media (in particular, CPU memory). Previous solutions relied on third-party components, which had compatibility issues across different versions.

### How was this patch tested?

Use the following script for testing:

```bash
export CUDA_VISIBLE_DEVICES=0
export TP=1
export MODEL_PATH=/model/Qwen3-14B
export MODEL_NAME=Qwen3-14B
export PORT=10000
#export ASCEND_LAUNCH_BLOCKING=1
#export ASCEND_SLOG_PRINT_TO_STDOUT=1

python3 -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 \
    --port ${PORT} \
    --dtype bfloat16 \
    --model ${MODEL_PATH} \
    --served-model-name ${MODEL_NAME} \
    --tensor-parallel-size ${TP} \
    --gpu-memory-utilization 0.7 \
    --max-model-len 32768 \
    --trust-remote-code \
    --disable-log-requests \
    --block-size 128 \
    --kv-transfer-config '{"kv_connector":"OffloadingConnector","kv_role":"kv_both","kv_connector_extra_config":{"block_size": 128, "num_cpu_blocks": 1000, "spec_name":"NPUOffloadingSpec", "spec_module_path": "vllm_ascend.kv_offload.npu"}}'
```

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: HF-001 <1670186653@qq.com>
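For quick offline verification, the same connector configuration can presumably also be passed through `KVTransferConfig` instead of the CLI flag. A minimal sketch, not part of this PR; the model path and prompt are assumptions:

```python
from vllm import LLM, SamplingParams
from vllm.config import KVTransferConfig

# Same connector settings as the serving command above.
ktc = KVTransferConfig(
    kv_connector="OffloadingConnector",
    kv_role="kv_both",
    kv_connector_extra_config={
        "block_size": 128,        # offloaded (CPU) block size, in tokens
        "num_cpu_blocks": 1000,   # capacity of the CPU-side cache
        "spec_name": "NPUOffloadingSpec",
        "spec_module_path": "vllm_ascend.kv_offload.npu",
    },
)

llm = LLM(model="/model/Qwen3-14B",  # assumed local path
          block_size=128,
          kv_transfer_config=ktc)
outputs = llm.generate(["Hello, world!"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```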
The new `NPUOffloadingSpec` implementation (72 lines, 2.8 KiB, Python):

```python
from collections.abc import Iterator
from typing import Optional

import torch
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.v1.kv_offload.abstract import LoadStoreSpec, OffloadingManager
from vllm.v1.kv_offload.backends.cpu import CPUBackend
from vllm.v1.kv_offload.lru_manager import LRUOffloadingManager
from vllm.v1.kv_offload.mediums import CPULoadStoreSpec, GPULoadStoreSpec
from vllm.v1.kv_offload.spec import OffloadingSpec
from vllm.v1.kv_offload.worker.worker import OffloadingHandler

from vllm_ascend.kv_offload.cpu_npu import CpuNpuOffloadingHandler

class NPUOffloadingSpec(OffloadingSpec):
    """Offloading spec wiring vLLM's CPU offloading backend to the NPU."""

    def __init__(self, vllm_config: VllmConfig):
        super().__init__(vllm_config)

        num_cpu_blocks = self.extra_config.get("num_cpu_blocks")
        if not num_cpu_blocks:
            raise Exception(
                "num_cpu_blocks must be specified in kv_connector_extra_config"
            )
        self.num_cpu_blocks: int = num_cpu_blocks

        # scheduler-side
        self._manager: Optional[OffloadingManager] = None

        # worker-side
        self._handler: Optional[OffloadingHandler] = None

    def get_manager(self) -> OffloadingManager:
        # Lazily build an LRU manager over a CPU backend; KV cache events
        # are emitted only if enabled in the vLLM config.
        if not self._manager:
            kv_events_config = self.vllm_config.kv_events_config
            enable_events = (kv_events_config is not None
                             and kv_events_config.enable_kv_cache_events)
            self._manager = LRUOffloadingManager(
                CPUBackend(block_size=self.offloaded_block_size,
                           num_blocks=self.num_cpu_blocks),
                enable_events=enable_events,
            )
        return self._manager

    def get_handlers(
        self, kv_caches: dict[str, torch.Tensor]
    ) -> Iterator[tuple[type[LoadStoreSpec], type[LoadStoreSpec],
                        OffloadingHandler]]:
        # Lazily build a single handler shared by both transfer directions.
        if not self._handler:
            layer_names = list(kv_caches.keys())
            layers = get_layers_from_vllm_config(self.vllm_config,
                                                 AttentionLayerBase,
                                                 layer_names)
            attn_backends = {
                layer_name: layers[layer_name].get_attn_backend()
                for layer_name in layer_names
            }

            self._handler = CpuNpuOffloadingHandler(
                attn_backends=attn_backends,
                gpu_block_size=self.gpu_block_size,
                cpu_block_size=self.offloaded_block_size,
                num_cpu_blocks=self.num_cpu_blocks,
                gpu_caches=kv_caches,
            )

        assert self._handler is not None
        # Offload (NPU -> CPU) and reload (CPU -> NPU) use the same handler.
        yield GPULoadStoreSpec, CPULoadStoreSpec, self._handler
        yield CPULoadStoreSpec, GPULoadStoreSpec, self._handler
```
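The `spec_name`/`spec_module_path` pair in `kv_connector_extra_config` lets vLLM's `OffloadingConnector` locate this class at runtime without a hard dependency on vllm-ascend. A rough sketch of that plugin-style lookup, illustrative only (the actual resolution logic lives in vLLM, not in this PR):

```python
import importlib

def resolve_offloading_spec(spec_module_path: str, spec_name: str):
    """Illustrative only: import a module by its dotted path and fetch
    the OffloadingSpec subclass named in the connector extra config."""
    module = importlib.import_module(spec_module_path)
    return getattr(module, spec_name)

# With this PR's config (requires vllm-ascend to be installed), the
# lookup resolves to NPUOffloadingSpec:
spec_cls = resolve_offloading_spec("vllm_ascend.kv_offload.npu",
                                   "NPUOffloadingSpec")
```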