[Feature]cpu offload connector (#1659)
This PR implements cpu offload connector to enable NPU kv cache offload
to host DRAM.
- vLLM version: v0.10.2
- vLLM main:
5aeb925452
Signed-off-by: lidenghui <lidenghui1110@gmail.com>
Signed-off-by: AlvisGong <gwly0401@163.com>
Signed-off-by: CalvinXKY <kyxiezju@163.com>
Co-authored-by: AlvisGong <gwly0401@163.com>
This commit is contained in:
@@ -554,7 +554,11 @@ class TestAscendMLAImpl(TestBase):
|
|||||||
self.impl.num_kv_heads = self.impl.num_heads
|
self.impl.num_kv_heads = self.impl.num_heads
|
||||||
|
|
||||||
decode_res, prefill_res = self.impl._mla_preprocess(
|
decode_res, prefill_res = self.impl._mla_preprocess(
|
||||||
hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False)
|
"mock_layer",
|
||||||
|
hidden_states,
|
||||||
|
kv_cache,
|
||||||
|
attn_metadata,
|
||||||
|
need_gather_q_kv=False)
|
||||||
|
|
||||||
self.assertIsNotNone(decode_res)
|
self.assertIsNotNone(decode_res)
|
||||||
self.assertIsNotNone(prefill_res)
|
self.assertIsNotNone(prefill_res)
|
||||||
|
|||||||
@@ -26,53 +26,21 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
|
|||||||
AttentionLayer, AttentionType)
|
AttentionLayer, AttentionType)
|
||||||
from vllm.attention.backends.utils import CommonAttentionState
|
from vllm.attention.backends.utils import CommonAttentionState
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
|
||||||
has_kv_transfer_group,
|
|
||||||
is_v1_kv_transfer_group)
|
|
||||||
from vllm.forward_context import ForwardContext, get_forward_context
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
from vllm.utils import cdiv, direct_register_custom_op
|
from vllm.utils import cdiv, direct_register_custom_op
|
||||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.kv_cache_interface import AttentionSpec
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
|
maybe_save_kv_layer_to_connector,
|
||||||
|
wait_for_kv_layer_from_connector)
|
||||||
from vllm_ascend.compilation.acl_graph import get_graph_params
|
from vllm_ascend.compilation.acl_graph import get_graph_params
|
||||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
from vllm_ascend.ops.attention import vanilla_chunked_prefill
|
||||||
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
|
||||||
nd_to_nz_2d, nd_to_nz_spec)
|
nd_to_nz_2d, nd_to_nz_spec)
|
||||||
|
|
||||||
|
|
||||||
def wait_for_kv_layer_from_connector(layer_name: str):
|
|
||||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
|
||||||
return
|
|
||||||
|
|
||||||
connector = get_kv_transfer_group()
|
|
||||||
|
|
||||||
forward_context: ForwardContext = get_forward_context()
|
|
||||||
attn_metadata = forward_context.attn_metadata
|
|
||||||
if attn_metadata is None:
|
|
||||||
return
|
|
||||||
# TODO: assert ascendMetadata
|
|
||||||
connector.wait_for_layer_load(layer_name)
|
|
||||||
|
|
||||||
|
|
||||||
def maybe_save_kv_layer_to_connector(
|
|
||||||
layer_name: str,
|
|
||||||
kv_cache_layer: List[torch.Tensor],
|
|
||||||
):
|
|
||||||
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
|
||||||
return
|
|
||||||
|
|
||||||
connector = get_kv_transfer_group()
|
|
||||||
|
|
||||||
forward_context: ForwardContext = get_forward_context()
|
|
||||||
attn_metadata = forward_context.attn_metadata
|
|
||||||
if attn_metadata is None:
|
|
||||||
return
|
|
||||||
# TODO: assert ascendMetadata
|
|
||||||
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
|
|
||||||
|
|
||||||
|
|
||||||
class AscendAttentionBackend(AttentionBackend):
|
class AscendAttentionBackend(AttentionBackend):
|
||||||
accept_output_buffer: bool = True
|
accept_output_buffer: bool = True
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,9 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
|
|||||||
from vllm_ascend.ascend_config import get_ascend_config
|
from vllm_ascend.ascend_config import get_ascend_config
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
|
||||||
split_decodes_and_prefills)
|
maybe_save_kv_layer_to_connector,
|
||||||
|
split_decodes_and_prefills,
|
||||||
|
wait_for_kv_layer_from_connector)
|
||||||
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
|
||||||
from vllm_ascend.multistream.context import get_multistream_comm_context
|
from vllm_ascend.multistream.context import get_multistream_comm_context
|
||||||
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
|
||||||
@@ -858,8 +860,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
current_ms_metadata.before_comm_event.wait()
|
current_ms_metadata.before_comm_event.wait()
|
||||||
return self._v_up_proj(attn_output)
|
return self._v_up_proj(attn_output)
|
||||||
|
|
||||||
def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
|
def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
|
||||||
need_gather_q_kv):
|
attn_metadata, need_gather_q_kv):
|
||||||
# MLA Preprocess:
|
# MLA Preprocess:
|
||||||
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
|
# 1. Perform q_a_proj and q_a_layernorm to obtain q_c
|
||||||
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
|
# 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
|
||||||
@@ -888,6 +890,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
|
kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
|
||||||
decode_preprocess_res = None
|
decode_preprocess_res = None
|
||||||
prefill_preprocess_res = None
|
prefill_preprocess_res = None
|
||||||
|
if has_prefill:
|
||||||
|
wait_for_kv_layer_from_connector(layer_name)
|
||||||
# Preprocess for decode tokens
|
# Preprocess for decode tokens
|
||||||
if has_decode:
|
if has_decode:
|
||||||
decode_q_c = q_c[:num_decode_tokens]
|
decode_q_c = q_c[:num_decode_tokens]
|
||||||
@@ -934,6 +938,7 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
self,
|
self,
|
||||||
|
layer_name,
|
||||||
hidden_states: torch.Tensor, # query in unified attn
|
hidden_states: torch.Tensor, # query in unified attn
|
||||||
kv_cache: Tuple[torch.Tensor],
|
kv_cache: Tuple[torch.Tensor],
|
||||||
attn_metadata: M,
|
attn_metadata: M,
|
||||||
@@ -960,7 +965,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
|
|
||||||
# MLA Preprocess
|
# MLA Preprocess
|
||||||
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
|
decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
|
||||||
hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
|
layer_name, hidden_states, kv_cache, attn_metadata,
|
||||||
|
need_gather_q_kv)
|
||||||
|
|
||||||
if decode_preprocess_res is not None:
|
if decode_preprocess_res is not None:
|
||||||
# MLA Preprocess for decoding
|
# MLA Preprocess for decoding
|
||||||
@@ -1018,4 +1024,8 @@ class AscendMLAImpl(MLAAttentionImpl):
|
|||||||
is_force_scatter=self.enable_shared_expert_dp)[0]
|
is_force_scatter=self.enable_shared_expert_dp)[0]
|
||||||
current_ms_metadata.after_comm_event.record()
|
current_ms_metadata.after_comm_event.record()
|
||||||
del o_proj_input
|
del o_proj_input
|
||||||
|
|
||||||
|
has_prefill = attn_metadata.num_prefills > 0
|
||||||
|
if has_prefill:
|
||||||
|
maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
|
||||||
return output_padded
|
return output_padded
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any
|
from typing import Any, List
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
|
||||||
|
has_kv_transfer_group,
|
||||||
|
is_v1_kv_transfer_group)
|
||||||
|
from vllm.forward_context import ForwardContext, get_forward_context
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -100,3 +104,34 @@ def split_decodes_and_prefills(
|
|||||||
num_decode_tokens = query_start_loc[first_prefill].item()
|
num_decode_tokens = query_start_loc[first_prefill].item()
|
||||||
num_prefill_tokens = num_tokens - num_decode_tokens
|
num_prefill_tokens = num_tokens - num_decode_tokens
|
||||||
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
|
return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
|
||||||
|
|
||||||
|
|
||||||
|
def wait_for_kv_layer_from_connector(layer_name: str):
|
||||||
|
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||||
|
return
|
||||||
|
|
||||||
|
connector = get_kv_transfer_group()
|
||||||
|
|
||||||
|
forward_context: ForwardContext = get_forward_context()
|
||||||
|
attn_metadata = forward_context.attn_metadata
|
||||||
|
if attn_metadata is None:
|
||||||
|
return
|
||||||
|
# TODO: assert ascendMetadata
|
||||||
|
connector.wait_for_layer_load(layer_name)
|
||||||
|
|
||||||
|
|
||||||
|
def maybe_save_kv_layer_to_connector(
|
||||||
|
layer_name: str,
|
||||||
|
kv_cache_layer: List[torch.Tensor],
|
||||||
|
):
|
||||||
|
if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
|
||||||
|
return
|
||||||
|
|
||||||
|
connector = get_kv_transfer_group()
|
||||||
|
|
||||||
|
forward_context: ForwardContext = get_forward_context()
|
||||||
|
attn_metadata = forward_context.attn_metadata
|
||||||
|
if attn_metadata is None:
|
||||||
|
return
|
||||||
|
# TODO: assert ascendMetadata
|
||||||
|
connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
|
||||||
|
|||||||
457
vllm_ascend/distributed/cpu_offload_connector.py
Normal file
457
vllm_ascend/distributed/cpu_offload_connector.py
Normal file
@@ -0,0 +1,457 @@
|
|||||||
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
|
import copy
|
||||||
|
import queue
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import TYPE_CHECKING, Any, Optional, Sequence
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from vllm.attention import AttentionType
|
||||||
|
from vllm.attention.layer import Attention
|
||||||
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
|
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
|
||||||
|
from vllm.distributed.parallel_state import get_pp_group, get_tp_group
|
||||||
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
||||||
|
from vllm.utils import logger
|
||||||
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
|
||||||
|
|
||||||
|
from vllm_ascend.distributed.cpu_offload_manager.metadata import (
|
||||||
|
MetadataServer, MetadataServerProc, MLAConfig)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
|
from vllm.forward_context import ForwardContext
|
||||||
|
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
|
||||||
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ReqMeta:
|
||||||
|
gpu_block_ids: list[int]
|
||||||
|
cpu_block_ids: list[int]
|
||||||
|
num_scheduled_tokens: int
|
||||||
|
num_computed_tokens: int
|
||||||
|
num_gpu_computed_tokens: int
|
||||||
|
num_cpu_computed_tokens: int
|
||||||
|
|
||||||
|
def update(self, other: "ReqMeta"):
|
||||||
|
self.gpu_block_ids.extend(other.gpu_block_ids)
|
||||||
|
self.cpu_block_ids.extend(other.cpu_block_ids)
|
||||||
|
self.num_scheduled_tokens = other.num_scheduled_tokens
|
||||||
|
self.num_computed_tokens = other.num_computed_tokens
|
||||||
|
self.num_gpu_computed_tokens = other.num_gpu_computed_tokens
|
||||||
|
self.num_cpu_computed_tokens = other.num_cpu_computed_tokens
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
|
||||||
|
requests: dict[str, ReqMeta]
|
||||||
|
finished_req_ids: set[str]
|
||||||
|
|
||||||
|
|
||||||
|
class CPUOffloadingConnector(KVConnectorBase_V1):
|
||||||
|
|
||||||
|
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
|
||||||
|
if not vllm_config.cache_config.enable_prefix_caching:
|
||||||
|
self.connector_scheduler: Optional[
|
||||||
|
CPUOffloadingConnectorScheduler] = None
|
||||||
|
self.connector_worker: Optional[
|
||||||
|
CPUOffloadingConnectorWorker] = None
|
||||||
|
elif role == KVConnectorRole.SCHEDULER:
|
||||||
|
self.connector_scheduler = CPUOffloadingConnectorScheduler(
|
||||||
|
vllm_config)
|
||||||
|
self.connector_worker = None
|
||||||
|
elif role == KVConnectorRole.WORKER:
|
||||||
|
self.connector_scheduler = None
|
||||||
|
self.connector_worker = CPUOffloadingConnectorWorker(vllm_config)
|
||||||
|
|
||||||
|
# ==============================
|
||||||
|
# Worker-side methods
|
||||||
|
# ==============================
|
||||||
|
|
||||||
|
def bind_connector_metadata(
|
||||||
|
self, connector_metadata: KVConnectorMetadata) -> None:
|
||||||
|
if self.connector_worker is not None:
|
||||||
|
assert isinstance(connector_metadata,
|
||||||
|
CPUOffloadingConnectorMetadata)
|
||||||
|
self.connector_worker.bind_connector_metadata(connector_metadata)
|
||||||
|
|
||||||
|
def clear_connector_metadata(self) -> None:
|
||||||
|
assert self.connector_worker is not None
|
||||||
|
self.connector_worker.clear_connector_metadata()
|
||||||
|
|
||||||
|
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||||
|
if self.connector_worker is not None:
|
||||||
|
self.connector_worker.register_kv_caches(kv_caches)
|
||||||
|
|
||||||
|
def start_load_kv(self, forward_context: "ForwardContext",
|
||||||
|
**kwargs) -> None:
|
||||||
|
if self.connector_worker is not None:
|
||||||
|
self.connector_worker.start_load_kv()
|
||||||
|
|
||||||
|
def wait_for_layer_load(self, layer_name: str) -> None:
|
||||||
|
if self.connector_worker is not None:
|
||||||
|
self.connector_worker.wait_for_layer_load()
|
||||||
|
|
||||||
|
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
|
||||||
|
attn_metadata: "AttentionMetadata", **kwargs) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def wait_for_save(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_finished(
|
||||||
|
self, finished_req_ids: set[str]
|
||||||
|
) -> tuple[Optional[set[str]], Optional[set[str]]]:
|
||||||
|
assert self.connector_worker is not None
|
||||||
|
return self.connector_worker.get_finished(), None
|
||||||
|
|
||||||
|
# Scheduler-side methods
|
||||||
|
# ==============================
|
||||||
|
|
||||||
|
def get_num_new_matched_tokens(
|
||||||
|
self, request: "Request",
|
||||||
|
num_computed_tokens: int) -> tuple[int, bool]:
|
||||||
|
if self.connector_scheduler is not None:
|
||||||
|
return self.connector_scheduler.get_num_new_matched_tokens(
|
||||||
|
request, num_computed_tokens)
|
||||||
|
return 0, False
|
||||||
|
|
||||||
|
def update_state_after_alloc(self, request: "Request",
|
||||||
|
blocks: "KVCacheBlocks",
|
||||||
|
num_external_tokens: int):
|
||||||
|
if self.connector_scheduler is not None:
|
||||||
|
return self.connector_scheduler.update_state_after_alloc(request)
|
||||||
|
|
||||||
|
def build_connector_meta(
|
||||||
|
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||||
|
if self.connector_scheduler is not None:
|
||||||
|
return self.connector_scheduler.build_connector_meta(
|
||||||
|
scheduler_output)
|
||||||
|
return KVConnectorMetadata()
|
||||||
|
|
||||||
|
def request_finished(
|
||||||
|
self, request: "Request",
|
||||||
|
block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
|
||||||
|
if self.connector_scheduler is not None:
|
||||||
|
self.connector_scheduler.request_finished(request)
|
||||||
|
return True, None
|
||||||
|
|
||||||
|
|
||||||
|
class CPUOffloadingConnectorScheduler:
|
||||||
|
|
||||||
|
def __init__(self, vllm_config: VllmConfig):
|
||||||
|
logger.info("init CPUOffloadingConnectorScheduler")
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
|
self.use_mla = vllm_config.model_config.use_mla
|
||||||
|
self.num_gpu_computed_tokens: dict[str, int] = {}
|
||||||
|
self.num_cpu_computed_tokens: dict[str, int] = {}
|
||||||
|
self.allocated_req_ids: set[str] = set()
|
||||||
|
self.finished_req_ids: list[str] = []
|
||||||
|
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
|
||||||
|
self.zmq_rpc_client.call("post_init")
|
||||||
|
if vllm_config.kv_transfer_config is not None:
|
||||||
|
self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
|
||||||
|
"swap_in_threshold", 0)
|
||||||
|
else:
|
||||||
|
self.swap_in_threshold = 0
|
||||||
|
logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
|
||||||
|
|
||||||
|
def get_num_new_matched_tokens(
|
||||||
|
self, ori_request: "Request",
|
||||||
|
num_computed_tokens: int) -> tuple[int, bool]:
|
||||||
|
request = copy.deepcopy(ori_request)
|
||||||
|
request.get_hash_new_full_blocks = None
|
||||||
|
num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
|
||||||
|
"get_matched_num_and_touch", request)
|
||||||
|
self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
|
||||||
|
self.num_cpu_computed_tokens[
|
||||||
|
request.request_id] = num_cpu_computed_tokens
|
||||||
|
if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
|
||||||
|
return num_cpu_computed_tokens - num_computed_tokens, load_async
|
||||||
|
else:
|
||||||
|
return 0, load_async
|
||||||
|
|
||||||
|
def update_state_after_alloc(self, request: "Request"):
|
||||||
|
self.allocated_req_ids.add(request.request_id)
|
||||||
|
|
||||||
|
def build_connector_meta(
|
||||||
|
self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
|
||||||
|
num_tokens = {}
|
||||||
|
# process scheduled_new_reqs
|
||||||
|
for req in scheduler_output.scheduled_new_reqs:
|
||||||
|
req_id = req.req_id
|
||||||
|
num_tokens[req_id] = (
|
||||||
|
req.num_computed_tokens +
|
||||||
|
scheduler_output.num_scheduled_tokens[req_id])
|
||||||
|
|
||||||
|
# process scheduled_cached_reqs
|
||||||
|
cached_reqs = scheduler_output.scheduled_cached_reqs
|
||||||
|
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||||
|
num_tokens[req_id] = (
|
||||||
|
cached_reqs.num_computed_tokens[idx] +
|
||||||
|
scheduler_output.num_scheduled_tokens[req_id])
|
||||||
|
|
||||||
|
unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
|
||||||
|
self.allocated_req_ids -
|
||||||
|
scheduler_output.num_scheduled_tokens.keys())
|
||||||
|
new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
|
||||||
|
num_tokens,
|
||||||
|
unallocated_req_ids)
|
||||||
|
metadata = CPUOffloadingConnectorMetadata(
|
||||||
|
requests={},
|
||||||
|
finished_req_ids=set(self.finished_req_ids),
|
||||||
|
)
|
||||||
|
for req in scheduler_output.scheduled_new_reqs:
|
||||||
|
req_id = req.req_id
|
||||||
|
gpu_block_ids = req.block_ids[0]
|
||||||
|
metadata.requests[req_id] = ReqMeta(
|
||||||
|
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||||
|
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||||
|
num_scheduled_tokens=scheduler_output.
|
||||||
|
num_scheduled_tokens[req_id],
|
||||||
|
num_computed_tokens=req.num_computed_tokens,
|
||||||
|
num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
|
||||||
|
num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
|
||||||
|
|
||||||
|
for idx, req_id in enumerate(cached_reqs.req_ids):
|
||||||
|
gpu_block_ids = cached_reqs.new_block_ids[idx]
|
||||||
|
metadata.requests[req_id] = ReqMeta(
|
||||||
|
gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
|
||||||
|
cpu_block_ids=new_cpu_block_ids.get(req_id, []),
|
||||||
|
num_scheduled_tokens=scheduler_output.
|
||||||
|
num_scheduled_tokens[req_id],
|
||||||
|
num_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||||
|
num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
|
||||||
|
num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
|
||||||
|
self.num_gpu_computed_tokens.clear()
|
||||||
|
self.num_cpu_computed_tokens.clear()
|
||||||
|
self.allocated_req_ids.clear()
|
||||||
|
self.finished_req_ids.clear()
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
def request_finished(self, ori_request: "Request"):
|
||||||
|
request = copy.deepcopy(ori_request)
|
||||||
|
request.get_hash_new_full_blocks = None
|
||||||
|
self.finished_req_ids.append(request.request_id)
|
||||||
|
# inform metadata server to record request, and free it after finish sending
|
||||||
|
self.zmq_rpc_client.call("record_request_cache_and_free_slots",
|
||||||
|
request)
|
||||||
|
|
||||||
|
|
||||||
|
class CPUOffloadingConnectorWorker:
|
||||||
|
|
||||||
|
def __init__(self, vllm_config: VllmConfig):
|
||||||
|
logger.info("init CPUOffloadingConnectorWorker")
|
||||||
|
self.vllm_config = vllm_config
|
||||||
|
self.block_size = vllm_config.cache_config.block_size
|
||||||
|
self.pp_rank = get_pp_group().rank_in_group
|
||||||
|
self.tp_group = get_tp_group()
|
||||||
|
self.tp_rank = self.tp_group.rank_in_group
|
||||||
|
self.tp_world_size = self.tp_group.world_size
|
||||||
|
self.use_mla = vllm_config.model_config.use_mla
|
||||||
|
|
||||||
|
self.requests: dict[str, ReqMeta] = {}
|
||||||
|
self.load_stream = torch.npu.Stream()
|
||||||
|
self.save_stream = torch.npu.Stream()
|
||||||
|
self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
|
||||||
|
self.load_block_mapping: list[tuple[int, int]] = []
|
||||||
|
self.save_input_queue: queue.Queue[tuple[str, ReqMeta]] = queue.Queue()
|
||||||
|
self.save_output_queue: queue.Queue[str] = queue.Queue()
|
||||||
|
self.save_thread = threading.Thread(target=self._save_listener)
|
||||||
|
self.save_thread.start()
|
||||||
|
self.done_sending_count: defaultdict[str, int] = defaultdict(int)
|
||||||
|
|
||||||
|
# start metadata server to init cpu_kv_cache_manager and handle rpc requests
|
||||||
|
# all dp shared the same metadata server, only start the process on data_rank 0
|
||||||
|
if vllm_config.parallel_config.data_parallel_rank == 0 and self.tp_rank == 0 and self.pp_rank == 0:
|
||||||
|
config = VllmConfig()
|
||||||
|
config.cache_config = vllm_config.cache_config
|
||||||
|
config.parallel_config = vllm_config.parallel_config
|
||||||
|
config.kv_transfer_config = vllm_config.kv_transfer_config
|
||||||
|
self.init_metadata_server(config)
|
||||||
|
self._wait_for_metadata_process_start()
|
||||||
|
|
||||||
|
def init_metadata_server(self, vllm_config: VllmConfig):
|
||||||
|
self.metadata_thread = threading.Thread(
|
||||||
|
target=MetadataServerProc.run_metadata_server,
|
||||||
|
args=(vllm_config, ),
|
||||||
|
)
|
||||||
|
self.metadata_thread.daemon = True
|
||||||
|
self.metadata_thread.start()
|
||||||
|
|
||||||
|
def _wait_for_metadata_process_start(self):
|
||||||
|
# TODO: wait for metadata server to start, add a rpc to check if ready
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
if self.zmq_rpc_client.call("ready"):
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
logger.info(f"wait for metadata server to start, error: {e}")
|
||||||
|
time.sleep(1)
|
||||||
|
|
||||||
|
def bind_connector_metadata(
|
||||||
|
self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
|
||||||
|
for req_id, req in connector_metadata.requests.items():
|
||||||
|
if req_id in self.requests:
|
||||||
|
self.requests[req_id].update(req)
|
||||||
|
req = self.requests[req_id]
|
||||||
|
else:
|
||||||
|
self.requests[req_id] = req
|
||||||
|
for i in range(req.num_gpu_computed_tokens // self.block_size,
|
||||||
|
req.num_computed_tokens // self.block_size):
|
||||||
|
self.load_block_mapping.append(
|
||||||
|
(req.cpu_block_ids[i], req.gpu_block_ids[i]))
|
||||||
|
for req_id in connector_metadata.finished_req_ids:
|
||||||
|
if req_id in self.requests:
|
||||||
|
self.save_input_queue.put((req_id, self.requests[req_id]))
|
||||||
|
|
||||||
|
def clear_connector_metadata(self) -> None:
|
||||||
|
self.load_block_mapping.clear()
|
||||||
|
|
||||||
|
def register_kv_caches(self, kv_caches: dict[str, Sequence[torch.Tensor]]):
|
||||||
|
self.gpu_kv_caches = kv_caches
|
||||||
|
model_config = self.vllm_config.model_config
|
||||||
|
mla_config: Optional[MLAConfig] = None
|
||||||
|
if model_config.use_mla:
|
||||||
|
mla_config = MLAConfig(
|
||||||
|
model_config.hf_text_config.kv_lora_rank,
|
||||||
|
model_config.hf_text_config.qk_rope_head_dim)
|
||||||
|
self.cpu_kv_caches = list(
|
||||||
|
self.zmq_rpc_client.call(
|
||||||
|
"init_cpu_kv_caches",
|
||||||
|
self.pp_rank,
|
||||||
|
self.tp_rank,
|
||||||
|
get_kv_cache_spec(self.vllm_config),
|
||||||
|
mla_config,
|
||||||
|
).values())
|
||||||
|
|
||||||
|
def start_load_kv(self) -> None:
|
||||||
|
self.current_layer = 0
|
||||||
|
self.gpu_kv_caches_load_iter = iter(self.gpu_kv_caches.values())
|
||||||
|
self.load_kv_layer(0)
|
||||||
|
|
||||||
|
def wait_for_layer_load(self) -> None:
|
||||||
|
# TODO: Replace with `torch.npu.current_stream().wait_stream(self.load_stream)` after fixing the bug.
|
||||||
|
self.load_stream.synchronize()
|
||||||
|
self.current_layer += 1
|
||||||
|
self.load_kv_layer(self.current_layer)
|
||||||
|
|
||||||
|
def load_kv_layer(self, layer: int):
|
||||||
|
if layer == len(self.gpu_kv_caches):
|
||||||
|
return
|
||||||
|
gpu_kv_caches = next(self.gpu_kv_caches_load_iter)
|
||||||
|
cpu_kv_caches = self.cpu_kv_caches[layer]
|
||||||
|
with torch.npu.stream(self.load_stream):
|
||||||
|
for cpu_block_id, gpu_block_id in self.load_block_mapping:
|
||||||
|
for gpu_layer_part, cpu_layer_part in zip(
|
||||||
|
gpu_kv_caches, cpu_kv_caches):
|
||||||
|
gpu_layer_part[gpu_block_id].copy_(
|
||||||
|
cpu_layer_part[cpu_block_id], non_blocking=True)
|
||||||
|
|
||||||
|
def get_finished(self) -> set[str]:
|
||||||
|
done_sending: set[str] = set()
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
id = self.save_output_queue.get_nowait()
|
||||||
|
except queue.Empty:
|
||||||
|
break
|
||||||
|
done_sending.add(id)
|
||||||
|
for id in done_sending:
|
||||||
|
del self.requests[id]
|
||||||
|
if self.tp_world_size == 1:
|
||||||
|
return done_sending
|
||||||
|
if self.tp_rank == 0:
|
||||||
|
for req_id in done_sending:
|
||||||
|
self.done_sending_count[req_id] += 1
|
||||||
|
other_ranks_finished_ids: list[str] = []
|
||||||
|
for i in range(1, self.tp_world_size):
|
||||||
|
other_ranks_finished_ids.extend(
|
||||||
|
self.tp_group.recv_object(src=i))
|
||||||
|
for req_id in other_ranks_finished_ids:
|
||||||
|
self.done_sending_count[req_id] += 1
|
||||||
|
all_done_sending: set[str] = set()
|
||||||
|
for req_id in list(self.done_sending_count.keys()):
|
||||||
|
if self.done_sending_count[req_id] == self.tp_world_size:
|
||||||
|
del self.done_sending_count[req_id]
|
||||||
|
all_done_sending.add(req_id)
|
||||||
|
# release cpu_kv_cache after request sending finished
|
||||||
|
# to avoid rpc blocking, use thread to call rpc asynchronously
|
||||||
|
sending_finished_thread = threading.Thread(
|
||||||
|
target=self._sending_finished, args=(all_done_sending, ))
|
||||||
|
sending_finished_thread.daemon = True
|
||||||
|
sending_finished_thread.start()
|
||||||
|
|
||||||
|
return all_done_sending
|
||||||
|
else:
|
||||||
|
self.tp_group.send_object(done_sending, dst=0)
|
||||||
|
return done_sending
|
||||||
|
|
||||||
|
def _sending_finished(self, all_done_sending):
|
||||||
|
for req_id in all_done_sending:
|
||||||
|
logger.debug(f"call cache_and_free_slots for req_id: {req_id}")
|
||||||
|
self.zmq_rpc_client.call("cache_and_free_slots", req_id)
|
||||||
|
|
||||||
|
def _save_listener(self):
|
||||||
|
save_block_mapping = []
|
||||||
|
while True:
|
||||||
|
req_id, req = self.save_input_queue.get()
|
||||||
|
for i in range(
|
||||||
|
req.num_cpu_computed_tokens // self.block_size,
|
||||||
|
min((req.num_computed_tokens + req.num_scheduled_tokens) //
|
||||||
|
self.block_size, len(req.cpu_block_ids))):
|
||||||
|
save_block_mapping.append(
|
||||||
|
(req.gpu_block_ids[i], req.cpu_block_ids[i]))
|
||||||
|
with torch.npu.stream(self.save_stream):
|
||||||
|
# MLA: kv_layer is tuple[tensor, tensor] means (rope, nope).
|
||||||
|
# non-MLA: kv_layer is list[tensor], typically means [k, v].
|
||||||
|
if self.use_mla:
|
||||||
|
start, step = self.tp_rank, self.tp_world_size
|
||||||
|
else:
|
||||||
|
start, step = 0, 1
|
||||||
|
for i in range(start, len(save_block_mapping), step):
|
||||||
|
gpu_block_id, cpu_block_id = save_block_mapping[i]
|
||||||
|
for cpu_kv_caches, gpu_kv_caches in zip(
|
||||||
|
self.cpu_kv_caches, self.gpu_kv_caches.values()):
|
||||||
|
for cpu_layer_part, gpu_layer_part in zip(
|
||||||
|
cpu_kv_caches, gpu_kv_caches):
|
||||||
|
cpu_layer_part[cpu_block_id].copy_(
|
||||||
|
gpu_layer_part[gpu_block_id],
|
||||||
|
non_blocking=True)
|
||||||
|
self.save_stream.synchronize()
|
||||||
|
self.save_output_queue.put(req_id)
|
||||||
|
save_block_mapping.clear()
|
||||||
|
|
||||||
|
|
||||||
|
# Copied from vllm_ascend/worker/model_runner_v1.py.
|
||||||
|
def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
|
||||||
|
forward_ctx = vllm_config.compilation_config.static_forward_context
|
||||||
|
block_size = vllm_config.cache_config.block_size
|
||||||
|
use_mla = vllm_config.model_config.use_mla
|
||||||
|
kv_cache_spec: dict[str, KVCacheSpec] = {}
|
||||||
|
for layer_name, attn_module in forward_ctx.items():
|
||||||
|
if isinstance(attn_module, FusedMoE):
|
||||||
|
continue
|
||||||
|
assert isinstance(attn_module, Attention)
|
||||||
|
if attn_module.attn_type == AttentionType.DECODER:
|
||||||
|
kv_cache_spec[layer_name] = FullAttentionSpec(
|
||||||
|
block_size=block_size,
|
||||||
|
num_kv_heads=attn_module.num_kv_heads,
|
||||||
|
head_size=attn_module.head_size,
|
||||||
|
dtype=attn_module.dtype,
|
||||||
|
use_mla=use_mla)
|
||||||
|
elif attn_module.attn_type in (AttentionType.ENCODER,
|
||||||
|
AttentionType.ENCODER_ONLY):
|
||||||
|
continue
|
||||||
|
elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown attention type: {attn_module.attn_type}")
|
||||||
|
return kv_cache_spec
|
||||||
@@ -0,0 +1,202 @@
|
|||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from vllm.utils import logger, sha256
|
||||||
|
from vllm.v1.core.block_pool import BlockPool
|
||||||
|
from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
|
||||||
|
PrefixCachingMetrics)
|
||||||
|
from vllm.v1.core.single_type_kv_cache_manager import \
|
||||||
|
get_manager_for_kv_cache_spec
|
||||||
|
from vllm.v1.kv_cache_interface import KVCacheSpec
|
||||||
|
from vllm.v1.metrics.stats import PrefixCacheStats
|
||||||
|
from vllm.v1.request import Request
|
||||||
|
|
||||||
|
|
||||||
|
class CPUCacheStats:
|
||||||
|
|
||||||
|
def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
|
||||||
|
self.enable_prefix_caching = enable_prefix_caching
|
||||||
|
self.log_stats = log_stats
|
||||||
|
self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
|
||||||
|
self.cpu_prefix_cache_metrics = PrefixCachingMetrics()
|
||||||
|
self.time_sec = int(time.time())
|
||||||
|
|
||||||
|
def log(self):
|
||||||
|
current_time_sec = int(time.time())
|
||||||
|
# Log the prefix cache hit rate every 10 seconds.
|
||||||
|
if current_time_sec - self.time_sec >= 10:
|
||||||
|
self.time_sec = current_time_sec
|
||||||
|
logger.info("CPU Prefix cache hit rate: %.1f%%",
|
||||||
|
self.cpu_prefix_cache_metrics.hit_rate * 100)
|
||||||
|
|
||||||
|
def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
|
||||||
|
"""Get (and reset) the prefix cache stats.
|
||||||
|
Returns:
|
||||||
|
The current prefix caching stats, or None if logging is disabled.
|
||||||
|
"""
|
||||||
|
if not self.log_stats:
|
||||||
|
return None
|
||||||
|
stats = self.prefix_cache_stats
|
||||||
|
self.prefix_cache_stats = PrefixCacheStats()
|
||||||
|
return stats
|
||||||
|
|
||||||
|
def update(self, num_tokens, num_computed_tokens):
|
||||||
|
# Note the function is called by scheduler
|
||||||
|
if self.log_stats and self.enable_prefix_caching:
|
||||||
|
assert self.prefix_cache_stats is not None
|
||||||
|
self.prefix_cache_stats.requests += 1
|
||||||
|
self.prefix_cache_stats.queries += num_tokens
|
||||||
|
self.prefix_cache_stats.hits += num_computed_tokens
|
||||||
|
|
||||||
|
def set_cache_stats(self, num_tokens, num_computed_tokens):
|
||||||
|
assert self.prefix_cache_stats is not None
|
||||||
|
self.prefix_cache_stats.hits = num_computed_tokens
|
||||||
|
self.prefix_cache_stats.queries = num_tokens
|
||||||
|
self.prefix_cache_stats.requests = 1
|
||||||
|
|
||||||
|
|
||||||
|
class CPUKVCacheManager:
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
kv_cache_spec: KVCacheSpec,
|
||||||
|
num_cpu_blocks: int,
|
||||||
|
caching_hash_algo: str = "builtin",
|
||||||
|
use_eagle: bool = False,
|
||||||
|
enable_kv_cache_events: bool = False,
|
||||||
|
) -> None:
|
||||||
|
self.block_size = kv_cache_spec.block_size
|
||||||
|
self.num_cpu_blocks = num_cpu_blocks
|
||||||
|
self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash
|
||||||
|
self.use_eagle = use_eagle
|
||||||
|
self.block_pool = BlockPool(self.num_cpu_blocks, True,
|
||||||
|
enable_kv_cache_events)
|
||||||
|
self.single_type_manager = get_manager_for_kv_cache_spec(
|
||||||
|
kv_cache_spec=kv_cache_spec,
|
||||||
|
block_pool=self.block_pool,
|
||||||
|
kv_cache_group_id=0,
|
||||||
|
)
|
||||||
|
# Record kv block hashes, avoid redundant computation.
|
||||||
|
self.req_to_block_hashes: defaultdict[
|
||||||
|
str, list[BlockHash]] = defaultdict(list)
|
||||||
|
# Record blocks touched in get_matched_num_and_touch().
|
||||||
|
self.req_to_computed_blocks: defaultdict[
|
||||||
|
str, list[KVCacheBlock]] = defaultdict(list)
|
||||||
|
# Record the request that failed to allocate.
|
||||||
|
self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool)
|
||||||
|
self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int)
|
||||||
|
self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True,
|
||||||
|
log_stats=True)
|
||||||
|
# Record request that will be free after finish sending
|
||||||
|
self.req_to_free: defaultdict[str, Request] = defaultdict(Request)
|
||||||
|
|
||||||
|
def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]:
|
||||||
|
# When the request requires prompt logprobs, we skip prefix caching.
|
||||||
|
if (request.sampling_params.prompt_logprobs is not None):
|
||||||
|
return 0, False
|
||||||
|
request_id = request.request_id
|
||||||
|
# The block hashes for the request may already be computed
|
||||||
|
# if the scheduler has tried to schedule the request before.
|
||||||
|
block_hashes = self.req_to_block_hashes[request_id]
|
||||||
|
if not block_hashes:
|
||||||
|
block_hashes = request.block_hashes
|
||||||
|
self.req_to_block_hashes[request_id] = block_hashes
|
||||||
|
max_cache_hit_length = request.num_tokens - 1
|
||||||
|
computed_blocks = self.single_type_manager.find_longest_cache_hit(
|
||||||
|
block_hashes=block_hashes,
|
||||||
|
max_length=max_cache_hit_length,
|
||||||
|
kv_cache_group_ids=[0],
|
||||||
|
block_pool=self.block_pool,
|
||||||
|
kv_cache_spec=self.single_type_manager.kv_cache_spec,
|
||||||
|
use_eagle=self.use_eagle,
|
||||||
|
)
|
||||||
|
num_computed_tokens = len(computed_blocks[0]) * self.block_size
|
||||||
|
self.req_to_computed_blocks[request_id] = computed_blocks[0]
|
||||||
|
# We should touch these blocks in the concurrent scenarios.
|
||||||
|
self.block_pool.touch(computed_blocks)
|
||||||
|
|
||||||
|
# cup prefix cache status set and log
|
||||||
|
assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
|
||||||
|
self.cpu_cache_stats.set_cache_stats(request.num_tokens,
|
||||||
|
num_computed_tokens)
|
||||||
|
self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
|
||||||
|
self.cpu_cache_stats.prefix_cache_stats)
|
||||||
|
self.cpu_cache_stats.log()
|
||||||
|
|
||||||
|
return num_computed_tokens, False
|
||||||
|
|
||||||
|
def _release_ahead_touch(self, request_id: str):
|
||||||
|
computed_blocks = self.req_to_computed_blocks[request_id]
|
||||||
|
if computed_blocks:
|
||||||
|
self.single_type_manager.block_pool.free_blocks(
|
||||||
|
reversed(computed_blocks))
|
||||||
|
self.req_to_computed_blocks.pop(request_id, None)
|
||||||
|
|
||||||
|
def allocate_slots(self, req_to_num_tokens: dict[str, int],
|
||||||
|
unallocated_req_ids: set[str]) -> dict[str, list[int]]:
|
||||||
|
for request_id in unallocated_req_ids:
|
||||||
|
self._free_slots(request_id)
|
||||||
|
req_to_new_blocks = {}
|
||||||
|
for request_id, num_tokens in req_to_num_tokens.items():
|
||||||
|
if self.req_failed_to_allocate[request_id]:
|
||||||
|
continue
|
||||||
|
new_computed_blocks = self.req_to_computed_blocks[request_id]
|
||||||
|
num_blocks_to_allocate = (
|
||||||
|
self.single_type_manager.get_num_blocks_to_allocate(
|
||||||
|
request_id=request_id,
|
||||||
|
num_tokens=num_tokens,
|
||||||
|
new_computed_blocks=new_computed_blocks,
|
||||||
|
))
|
||||||
|
if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
|
||||||
|
self._release_ahead_touch(request_id)
|
||||||
|
self.req_failed_to_allocate[request_id] = True
|
||||||
|
continue
|
||||||
|
# Append the new computed blocks to the request blocks until now to
|
||||||
|
# avoid the case where the new blocks cannot be allocated.
|
||||||
|
self.single_type_manager.save_new_computed_blocks(
|
||||||
|
request_id, new_computed_blocks)
|
||||||
|
# Allocate new blocks but do not cache now.
|
||||||
|
new_blocks = self.single_type_manager.allocate_new_blocks(
|
||||||
|
request_id, num_tokens)
|
||||||
|
self.req_to_num_tokens[request_id] = num_tokens
|
||||||
|
# No need to release ref_cnt because we use officially.
|
||||||
|
self.req_to_computed_blocks.pop(request_id, None)
|
||||||
|
req_to_new_blocks[request_id] = [
|
||||||
|
block.block_id for block in new_computed_blocks + new_blocks
|
||||||
|
]
|
||||||
|
return req_to_new_blocks
|
||||||
|
|
||||||
|
def record_request_cache_and_free_slots(self, request: Request):
|
||||||
|
logger.debug(
|
||||||
|
f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
|
||||||
|
)
|
||||||
|
self.req_to_free[request.request_id] = request
|
||||||
|
|
||||||
|
def cache_and_free_slots(self, request_id: str):
|
||||||
|
logger.debug(
|
||||||
|
f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
|
||||||
|
)
|
||||||
|
if request_id not in self.req_to_free:
|
||||||
|
logger.Error(
|
||||||
|
f"request {request_id} not in req_to_free, maybe bug!")
|
||||||
|
return
|
||||||
|
request = self.req_to_free[request_id]
|
||||||
|
if not self.req_failed_to_allocate[request_id]:
|
||||||
|
self.single_type_manager.cache_blocks(
|
||||||
|
request,
|
||||||
|
self.req_to_num_tokens[request_id],
|
||||||
|
)
|
||||||
|
self._free_slots(request_id)
|
||||||
|
logger.debug(
|
||||||
|
f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
|
||||||
|
del self.req_to_free[request_id]
|
||||||
|
|
||||||
|
def _free_slots(self, request_id: str):
|
||||||
|
# This function is designed to be reentrant.
|
||||||
|
self._release_ahead_touch(request_id)
|
||||||
|
self.single_type_manager.free(request_id)
|
||||||
|
self.req_to_block_hashes.pop(request_id, None)
|
||||||
|
self.req_to_computed_blocks.pop(request_id, None)
|
||||||
|
self.req_failed_to_allocate.pop(request_id, None)
|
||||||
|
self.req_to_num_tokens.pop(request_id, None)
|
||||||
269
vllm_ascend/distributed/cpu_offload_manager/metadata.py
Normal file
269
vllm_ascend/distributed/cpu_offload_manager/metadata.py
Normal file
@@ -0,0 +1,269 @@
|
|||||||
|
import math
|
||||||
|
import os
|
||||||
|
import pickle
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from multiprocessing.shared_memory import SharedMemory
|
||||||
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import vllm.envs as envs
|
||||||
|
import zmq
|
||||||
|
from vllm.config import KVTransferConfig, VllmConfig
|
||||||
|
from vllm.utils import get_dtype_size, logger, make_zmq_socket
|
||||||
|
from vllm.v1.kv_cache_interface import AttentionSpec
|
||||||
|
|
||||||
|
from vllm_ascend.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
|
||||||
|
CPUKVCacheManager
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MLAConfig:
|
||||||
|
nope_dim: int
|
||||||
|
rope_dim: int
|
||||||
|
|
||||||
|
|
||||||
|
def get_cpu_offload_connector(vllm_config: VllmConfig) -> KVTransferConfig:
|
||||||
|
if vllm_config.kv_transfer_config is not None:
|
||||||
|
kv_transfer_config = vllm_config.kv_transfer_config
|
||||||
|
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
||||||
|
return kv_transfer_config
|
||||||
|
elif kv_transfer_config.kv_connector == "MultiConnector":
|
||||||
|
ktcs = kv_transfer_config.kv_connector_extra_config.get(
|
||||||
|
"connectors")
|
||||||
|
for ktc in ktcs:
|
||||||
|
kv_transfer_config = KVTransferConfig(**ktc)
|
||||||
|
if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
|
||||||
|
return kv_transfer_config
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataServer:
|
||||||
|
METADATA_SERVER_ADDRESS = f"ipc://{envs.VLLM_RPC_BASE_PATH}/metadata.ipc"
|
||||||
|
DEFAULT_CPU_SWAP_SPACE_GB = 800
|
||||||
|
|
||||||
|
class ZMQRPCClient:
|
||||||
|
|
||||||
|
def __init__(self, identity=f"worker-{os.getpid()}"):
|
||||||
|
logger.info(f"metadata client for worker {identity} started")
|
||||||
|
self.ctx = zmq.Context() # type: ignore
|
||||||
|
self.socket = make_zmq_socket(
|
||||||
|
self.ctx,
|
||||||
|
MetadataServer.METADATA_SERVER_ADDRESS,
|
||||||
|
zmq.DEALER, # type: ignore
|
||||||
|
bind=False,
|
||||||
|
identity=identity.encode(),
|
||||||
|
linger=0)
|
||||||
|
|
||||||
|
def call(self, func_name: str, *args, **kwargs) -> Any:
|
||||||
|
request = (func_name, args, kwargs)
|
||||||
|
self.socket.send(b"", zmq.SNDMORE) # type: ignore
|
||||||
|
self.socket.send(pickle.dumps(request))
|
||||||
|
_ = self.socket.recv()
|
||||||
|
response = pickle.loads(self.socket.recv())
|
||||||
|
result, error = response
|
||||||
|
if error:
|
||||||
|
logger.exception(f"call metadata sever error: {error}")
|
||||||
|
raise error
|
||||||
|
if func_name == "init_cpu_kv_caches":
|
||||||
|
(memory_dict, layer_size, layer_dtype, mla_config) = result
|
||||||
|
# shared_memory_dict is recorded in self to close
|
||||||
|
self.shared_memory_dict = memory_dict
|
||||||
|
result = {}
|
||||||
|
for key, shm in memory_dict.items():
|
||||||
|
tensor = torch.frombuffer(
|
||||||
|
shm.buf, dtype=layer_dtype).reshape(layer_size)
|
||||||
|
if mla_config is not None:
|
||||||
|
tensor = tensor.split(
|
||||||
|
[mla_config.nope_dim, mla_config.rope_dim], dim=-1)
|
||||||
|
result[key] = tensor
|
||||||
|
return result
|
||||||
|
|
||||||
|
def __del__(self):
|
||||||
|
# will be finalized by outer process
|
||||||
|
self.socket.close()
|
||||||
|
self.ctx.term()
|
||||||
|
if hasattr(self, 'shared_memory_dict'):
|
||||||
|
for shm in self.shared_memory_dict.values():
|
||||||
|
shm.close()
|
||||||
|
|
||||||
|
def __init__(self, vllm_config: VllmConfig):
|
||||||
|
self.world_size = vllm_config.parallel_config.world_size
|
||||||
|
self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
|
||||||
|
kv_transfer_config = get_cpu_offload_connector(vllm_config)
|
||||||
|
assert kv_transfer_config is not None
|
||||||
|
available_memory_gb = kv_transfer_config.get_from_extra_config(
|
||||||
|
"cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
|
||||||
|
self.available_memory = available_memory_gb * 1024 * 1024 * 1024
|
||||||
|
logger.info(f"cpu swap space: {self.available_memory} bytes")
|
||||||
|
self.ctx = zmq.Context() # type: ignore
|
||||||
|
self.socket = make_zmq_socket(
|
||||||
|
self.ctx,
|
||||||
|
MetadataServer.METADATA_SERVER_ADDRESS,
|
||||||
|
zmq.ROUTER, # type: ignore
|
||||||
|
bind=True,
|
||||||
|
linger=0)
|
||||||
|
self.functions: dict[str, Callable] = {
|
||||||
|
"init_cpu_kv_caches": self.init_cpu_kv_caches,
|
||||||
|
"post_init": self.post_init,
|
||||||
|
"ready": self.ready,
|
||||||
|
}
|
||||||
|
self.shared_memory = {} # type: ignore
|
||||||
|
self.num_cpu_blocks = -1
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _safe_create_shared_memory(name: str, size: int) -> SharedMemory:
|
||||||
|
try:
|
||||||
|
existing_shm = SharedMemory(name=name, create=False)
|
||||||
|
existing_shm.close()
|
||||||
|
existing_shm.unlink()
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
return SharedMemory(name=name, create=True, size=size)
|
||||||
|
|
||||||
|
def ready(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def init_cpu_kv_caches(
|
||||||
|
self,
|
||||||
|
pp_rank: int,
|
||||||
|
tp_rank: int,
|
||||||
|
kv_cache_specs: dict[str, AttentionSpec],
|
||||||
|
mla_config: MLAConfig,
|
||||||
|
) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
|
||||||
|
MLAConfig]:
|
||||||
|
logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
|
||||||
|
# follow the assumption that each layer has the same spec
|
||||||
|
layer = next(iter(kv_cache_specs.values()))
|
||||||
|
assert all([
|
||||||
|
layer.page_size_bytes == any.page_size_bytes
|
||||||
|
for any in kv_cache_specs.values()
|
||||||
|
])
|
||||||
|
# mla shares the same kv cache among different tp
|
||||||
|
if layer.use_mla:
|
||||||
|
tp_rank = 0
|
||||||
|
if (pp_rank, tp_rank) in self.shared_memory:
|
||||||
|
return self.shared_memory[(pp_rank, tp_rank)]
|
||||||
|
available_memory = self.available_memory
|
||||||
|
shared_memory_dict = {}
|
||||||
|
if layer.use_mla:
|
||||||
|
available_memory //= self.pipeline_parallel_size
|
||||||
|
available_memory //= len(kv_cache_specs)
|
||||||
|
num_blocks = available_memory // layer.page_size_bytes
|
||||||
|
layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
|
||||||
|
layer.head_size) # type: ignore
|
||||||
|
else:
|
||||||
|
available_memory //= self.world_size
|
||||||
|
available_memory //= len(kv_cache_specs)
|
||||||
|
num_blocks = available_memory // layer.page_size_bytes
|
||||||
|
layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
|
||||||
|
layer.head_size) # type: ignore
|
||||||
|
nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
|
||||||
|
for layer_name in kv_cache_specs.keys():
|
||||||
|
# only this format can share during ZeroMQ+pickle
|
||||||
|
shared_memory_dict[
|
||||||
|
layer_name] = MetadataServer._safe_create_shared_memory(
|
||||||
|
f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
|
||||||
|
if layer.use_mla:
|
||||||
|
assert mla_config is not None
|
||||||
|
assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
|
||||||
|
self.shared_memory[(pp_rank,
|
||||||
|
tp_rank)] = (shared_memory_dict, layer_size,
|
||||||
|
layer.dtype, mla_config)
|
||||||
|
else:
|
||||||
|
self.shared_memory[(pp_rank,
|
||||||
|
tp_rank)] = (shared_memory_dict, layer_size,
|
||||||
|
layer.dtype, None)
|
||||||
|
if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
|
||||||
|
self.num_cpu_blocks = num_blocks
|
||||||
|
self.layer = layer
|
||||||
|
return self.shared_memory[(pp_rank, tp_rank)]
|
||||||
|
|
||||||
|
def post_init(self):
|
||||||
|
# different processors in data parallel may call multiple times
|
||||||
|
if hasattr(self, 'cpu_block_manager'):
|
||||||
|
return
|
||||||
|
# do shared_memory() at least once
|
||||||
|
logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
|
||||||
|
assert self.num_cpu_blocks >= 0
|
||||||
|
self.cpu_block_manager = CPUKVCacheManager(self.layer,
|
||||||
|
self.num_cpu_blocks)
|
||||||
|
self.functions.update({
|
||||||
|
"get_matched_num_and_touch":
|
||||||
|
self.cpu_block_manager.get_matched_num_and_touch,
|
||||||
|
"allocate_slots":
|
||||||
|
self.cpu_block_manager.allocate_slots,
|
||||||
|
"record_request_cache_and_free_slots":
|
||||||
|
self.cpu_block_manager.record_request_cache_and_free_slots,
|
||||||
|
"cache_and_free_slots":
|
||||||
|
self.cpu_block_manager.cache_and_free_slots,
|
||||||
|
})
|
||||||
|
|
||||||
|
def serve_step(self):
|
||||||
|
client_id = self.socket.recv()
|
||||||
|
_ = self.socket.recv()
|
||||||
|
raw_msg = self.socket.recv()
|
||||||
|
try:
|
||||||
|
func_name, args, kwargs = pickle.loads(raw_msg)
|
||||||
|
except Exception as e:
|
||||||
|
response = (None, Exception(f"Invalid request: {str(e)}"))
|
||||||
|
else:
|
||||||
|
if func_name in self.functions:
|
||||||
|
try:
|
||||||
|
result = self.functions[func_name](*args, **kwargs)
|
||||||
|
response = (result, None) # type: ignore
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"metadata execute error: {e}")
|
||||||
|
response = (None, e) # type: ignore
|
||||||
|
else:
|
||||||
|
response = (None, NameError(f"Function {func_name} not found"))
|
||||||
|
self.socket.send(client_id, zmq.SNDMORE) # type: ignore
|
||||||
|
self.socket.send(b"", zmq.SNDMORE) # type: ignore
|
||||||
|
self.socket.send(pickle.dumps(response))
|
||||||
|
|
||||||
|
def shutdown(self):
|
||||||
|
self.socket.close()
|
||||||
|
self.ctx.term()
|
||||||
|
socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
|
||||||
|
"ipc://", "")
|
||||||
|
if os.path.exists(socket_path):
|
||||||
|
os.remove(socket_path)
|
||||||
|
for cached in self.shared_memory.values():
|
||||||
|
for shm in cached[0].values():
|
||||||
|
shm.close()
|
||||||
|
shm.unlink()
|
||||||
|
|
||||||
|
|
||||||
|
class MetadataServerProc:
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def run_metadata_server(vllm_config: VllmConfig):
|
||||||
|
if (not vllm_config.cache_config.enable_prefix_caching
|
||||||
|
or get_cpu_offload_connector(vllm_config) is None):
|
||||||
|
return
|
||||||
|
|
||||||
|
shutdown_requested = False
|
||||||
|
|
||||||
|
def _signal_handler(signum, frame):
|
||||||
|
nonlocal shutdown_requested
|
||||||
|
if not shutdown_requested:
|
||||||
|
shutdown_requested = True
|
||||||
|
raise SystemExit()
|
||||||
|
|
||||||
|
# Either SIGTERM or SIGINT will terminate the worker
|
||||||
|
# signal.signal(signal.SIGTERM, _signal_handler)
|
||||||
|
# signal.signal(signal.SIGINT, _signal_handler)
|
||||||
|
metadata_server: Optional[MetadataServer] = None
|
||||||
|
try:
|
||||||
|
metadata_server = MetadataServer(vllm_config)
|
||||||
|
logger.info("Metadata server started.")
|
||||||
|
while True:
|
||||||
|
metadata_server.serve_step()
|
||||||
|
except SystemExit:
|
||||||
|
logger.info("Metadata server exiting.")
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"Metadata server error: {e}.")
|
||||||
|
raise e
|
||||||
|
finally:
|
||||||
|
if metadata_server is not None:
|
||||||
|
metadata_server.shutdown()
|
||||||
@@ -156,8 +156,9 @@ def mla_forward(
|
|||||||
else:
|
else:
|
||||||
attn_metadata = forward_context.attn_metadata
|
attn_metadata = forward_context.attn_metadata
|
||||||
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
|
kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
|
||||||
self.mla_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
|
self.mla_attn.impl.forward(self.mla_attn.layer_name, hidden_states,
|
||||||
need_gather_q_kv, output)
|
kv_cache, attn_metadata, need_gather_q_kv,
|
||||||
|
output)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user