diff --git a/tests/ut/attention/test_mla_v1.py b/tests/ut/attention/test_mla_v1.py
index 5fdc202..0164057 100644
--- a/tests/ut/attention/test_mla_v1.py
+++ b/tests/ut/attention/test_mla_v1.py
@@ -554,7 +554,11 @@ class TestAscendMLAImpl(TestBase):
         self.impl.num_kv_heads = self.impl.num_heads
 
         decode_res, prefill_res = self.impl._mla_preprocess(
-            hidden_states, kv_cache, attn_metadata, need_gather_q_kv=False)
+            "mock_layer",
+            hidden_states,
+            kv_cache,
+            attn_metadata,
+            need_gather_q_kv=False)
 
         self.assertIsNotNone(decode_res)
         self.assertIsNotNone(prefill_res)
diff --git a/tests/ut/torchair/models/test_torchair_deepseek_v2.py b/tests/ut/torchair/models/test_torchair_deepseek_v2.py
index 3942144..5a7c2a2 100644
--- a/tests/ut/torchair/models/test_torchair_deepseek_v2.py
+++ b/tests/ut/torchair/models/test_torchair_deepseek_v2.py
@@ -328,4 +328,4 @@ def test_torchair_deepseek_v2_for_causal_lm(mock_distributed, vllm_config):
             "vllm.model_executor.model_loader.weight_utils.default_weight_loader"
     ):
         loaded = model.load_weights(weights)
-    assert loaded is not None
\ No newline at end of file
+    assert loaded is not None
diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py
index fca31df..511d3ad 100644
--- a/vllm_ascend/attention/attention_v1.py
+++ b/vllm_ascend/attention/attention_v1.py
@@ -26,53 +26,21 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionLayer, AttentionType)
 from vllm.attention.backends.utils import CommonAttentionState
 from vllm.config import VllmConfig
-from vllm.distributed.kv_transfer import (get_kv_transfer_group,
-                                          has_kv_transfer_group,
-                                          is_v1_kv_transfer_group)
 from vllm.forward_context import ForwardContext, get_forward_context
 from vllm.utils import cdiv, direct_register_custom_op
 from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.kv_cache_interface import AttentionSpec
 
-from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
+from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
+                                         maybe_save_kv_layer_to_connector,
+                                         wait_for_kv_layer_from_connector)
 from vllm_ascend.compilation.acl_graph import get_graph_params
 from vllm_ascend.ops.attention import vanilla_chunked_prefill
 from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_NZ, aligned_16, is_310p,
                                nd_to_nz_2d, nd_to_nz_spec)
 
 
-def wait_for_kv_layer_from_connector(layer_name: str):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
-
-    connector = get_kv_transfer_group()
-
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    # TODO: assert ascendMetadata
-    connector.wait_for_layer_load(layer_name)
-
-
-def maybe_save_kv_layer_to_connector(
-    layer_name: str,
-    kv_cache_layer: List[torch.Tensor],
-):
-    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
-        return
-
-    connector = get_kv_transfer_group()
-
-    forward_context: ForwardContext = get_forward_context()
-    attn_metadata = forward_context.attn_metadata
-    if attn_metadata is None:
-        return
-    # TODO: assert ascendMetadata
-    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
-
-
 class AscendAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
 
diff --git a/vllm_ascend/attention/mla_v1.py b/vllm_ascend/attention/mla_v1.py
index 7a85bbb..cb15bd1 100644
--- a/vllm_ascend/attention/mla_v1.py
+++ b/vllm_ascend/attention/mla_v1.py
@@ -18,7 +18,9 @@ from vllm.v1.attention.backends.utils import AttentionCGSupport
 from vllm_ascend.ascend_config import get_ascend_config
 from vllm_ascend.attention.attention_v1 import AscendAttentionState
 from vllm_ascend.attention.utils import (AscendCommonAttentionMetadata,
-                                         split_decodes_and_prefills)
+                                         maybe_save_kv_layer_to_connector,
+                                         split_decodes_and_prefills,
+                                         wait_for_kv_layer_from_connector)
 from vllm_ascend.multistream.base import MSAttentionMetadataSplitConfig
 from vllm_ascend.multistream.context import get_multistream_comm_context
 from vllm_ascend.multistream.ms_split import model_input_split_v1_mla_attn
@@ -858,8 +860,8 @@ class AscendMLAImpl(MLAAttentionImpl):
             current_ms_metadata.before_comm_event.wait()
             return self._v_up_proj(attn_output)
 
-    def _mla_preprocess(self, hidden_states, kv_cache, attn_metadata,
-                        need_gather_q_kv):
+    def _mla_preprocess(self, layer_name, hidden_states, kv_cache,
+                        attn_metadata, need_gather_q_kv):
         # MLA Preprocess:
         # 1. Perform q_a_proj and q_a_layernorm to obtain q_c
         # 2. Perform kv_a_proj_with_mqa to obtain kv_no_split
@@ -888,6 +890,8 @@ class AscendMLAImpl(MLAAttentionImpl):
             kv_no_split = get_tp_group().all_gather(kv_no_split, 0)
         decode_preprocess_res = None
         prefill_preprocess_res = None
+        if has_prefill:
+            wait_for_kv_layer_from_connector(layer_name)
         # Preprocess for decode tokens
         if has_decode:
             decode_q_c = q_c[:num_decode_tokens]
@@ -934,6 +938,7 @@ class AscendMLAImpl(MLAAttentionImpl):
 
     def forward(
         self,
+        layer_name,
         hidden_states: torch.Tensor,  # query in unified attn
         kv_cache: Tuple[torch.Tensor],
         attn_metadata: M,
@@ -960,7 +965,8 @@ class AscendMLAImpl(MLAAttentionImpl):
 
         # MLA Preprocess
         decode_preprocess_res, prefill_preprocess_res = self._mla_preprocess(
-            hidden_states, kv_cache, attn_metadata, need_gather_q_kv)
+            layer_name, hidden_states, kv_cache, attn_metadata,
+            need_gather_q_kv)
 
         if decode_preprocess_res is not None:
             # MLA Preprocess for decoding
@@ -1018,4 +1024,8 @@ class AscendMLAImpl(MLAAttentionImpl):
                     is_force_scatter=self.enable_shared_expert_dp)[0]
                 current_ms_metadata.after_comm_event.record()
             del o_proj_input
+
+        has_prefill = attn_metadata.num_prefills > 0
+        if has_prefill:
+            maybe_save_kv_layer_to_connector(layer_name, list(kv_cache))
         return output_padded
diff --git a/vllm_ascend/attention/utils.py b/vllm_ascend/attention/utils.py
index 01cf4ea..519cde0 100644
--- a/vllm_ascend/attention/utils.py
+++ b/vllm_ascend/attention/utils.py
@@ -1,7 +1,11 @@
 from dataclasses import dataclass
-from typing import Any
+from typing import Any, List
 
 import torch
+from vllm.distributed.kv_transfer import (get_kv_transfer_group,
+                                          has_kv_transfer_group,
+                                          is_v1_kv_transfer_group)
+from vllm.forward_context import ForwardContext, get_forward_context
 
 
 @dataclass
@@ -100,3 +104,34 @@ def split_decodes_and_prefills(
     num_decode_tokens = query_start_loc[first_prefill].item()
     num_prefill_tokens = num_tokens - num_decode_tokens
     return (num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens)
+
+
+def wait_for_kv_layer_from_connector(layer_name: str):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    # TODO: assert ascendMetadata
+    connector.wait_for_layer_load(layer_name)
+
+
+def maybe_save_kv_layer_to_connector(
+    layer_name: str,
+    kv_cache_layer: List[torch.Tensor],
+):
+    if not has_kv_transfer_group() or not is_v1_kv_transfer_group():
+        return
+
+    connector = get_kv_transfer_group()
+
+    forward_context: ForwardContext = get_forward_context()
+    attn_metadata = forward_context.attn_metadata
+    if attn_metadata is None:
+        return
+    # TODO: assert ascendMetadata
+    connector.save_kv_layer(layer_name, kv_cache_layer, attn_metadata)
diff --git a/vllm_ascend/distributed/cpu_offload_connector.py b/vllm_ascend/distributed/cpu_offload_connector.py
new file mode 100644
index 0000000..b27595d
--- /dev/null
+++ b/vllm_ascend/distributed/cpu_offload_connector.py
@@ -0,0 +1,457 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import copy
+import queue
+import threading
+import time
+from collections import defaultdict
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Any, Optional, Sequence
+
+import torch
+from vllm.attention import AttentionType
+from vllm.attention.layer import Attention
+from vllm.config import VllmConfig
+from vllm.distributed.kv_transfer.kv_connector.v1.base import (
+    KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
+from vllm.distributed.parallel_state import get_pp_group, get_tp_group
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.utils import logger
+from vllm.v1.core.sched.output import SchedulerOutput
+from vllm.v1.kv_cache_interface import FullAttentionSpec, KVCacheSpec
+
+from vllm_ascend.distributed.cpu_offload_manager.metadata import (
+    MetadataServer, MetadataServerProc, MLAConfig)
+
+if TYPE_CHECKING:
+    from vllm.attention.backends.abstract import AttentionMetadata
+    from vllm.forward_context import ForwardContext
+    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
+    from vllm.v1.request import Request
+
+
+@dataclass
+class ReqMeta:
+    gpu_block_ids: list[int]
+    cpu_block_ids: list[int]
+    num_scheduled_tokens: int
+    num_computed_tokens: int
+    num_gpu_computed_tokens: int
+    num_cpu_computed_tokens: int
+
+    def update(self, other: "ReqMeta"):
+        self.gpu_block_ids.extend(other.gpu_block_ids)
+        self.cpu_block_ids.extend(other.cpu_block_ids)
+        self.num_scheduled_tokens = other.num_scheduled_tokens
+        self.num_computed_tokens = other.num_computed_tokens
+        self.num_gpu_computed_tokens = other.num_gpu_computed_tokens
+        self.num_cpu_computed_tokens = other.num_cpu_computed_tokens
+
+
+@dataclass
+class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
+    requests: dict[str, ReqMeta]
+    finished_req_ids: set[str]
+
+
+class CPUOffloadingConnector(KVConnectorBase_V1):
+
+    def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
+        if not vllm_config.cache_config.enable_prefix_caching:
+            self.connector_scheduler: Optional[
+                CPUOffloadingConnectorScheduler] = None
+            self.connector_worker: Optional[
+                CPUOffloadingConnectorWorker] = None
+        elif role == KVConnectorRole.SCHEDULER:
+            self.connector_scheduler = CPUOffloadingConnectorScheduler(
+                vllm_config)
+            self.connector_worker = None
+        elif role == KVConnectorRole.WORKER:
+            self.connector_scheduler = None
+            self.connector_worker = CPUOffloadingConnectorWorker(vllm_config)
+
+    # ==============================
+    # Worker-side methods
+    # ==============================
+
+    def bind_connector_metadata(
+            self, connector_metadata: KVConnectorMetadata) -> None:
+        if self.connector_worker is not None:
+            assert isinstance(connector_metadata,
+                              CPUOffloadingConnectorMetadata)
+            self.connector_worker.bind_connector_metadata(connector_metadata)
+
+    def clear_connector_metadata(self) -> None:
+        assert self.connector_worker is not None
+        self.connector_worker.clear_connector_metadata()
+
+    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
+        if self.connector_worker is not None:
+            self.connector_worker.register_kv_caches(kv_caches)
+
+    def start_load_kv(self, forward_context: "ForwardContext",
+                      **kwargs) -> None:
+        if self.connector_worker is not None:
+            self.connector_worker.start_load_kv()
+
+    def wait_for_layer_load(self, layer_name: str) -> None:
+        if self.connector_worker is not None:
+            self.connector_worker.wait_for_layer_load()
+
+    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
+                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
+        pass
+
+    def wait_for_save(self):
+        pass
+
+    def get_finished(
+        self, finished_req_ids: set[str]
+    ) -> tuple[Optional[set[str]], Optional[set[str]]]:
+        assert self.connector_worker is not None
+        return self.connector_worker.get_finished(), None
+
+    # ==============================
+    # Scheduler-side methods
+    # ==============================
+
+    def get_num_new_matched_tokens(
+            self, request: "Request",
+            num_computed_tokens: int) -> tuple[int, bool]:
+        if self.connector_scheduler is not None:
+            return self.connector_scheduler.get_num_new_matched_tokens(
+                request, num_computed_tokens)
+        return 0, False
+
+    def update_state_after_alloc(self, request: "Request",
+                                 blocks: "KVCacheBlocks",
+                                 num_external_tokens: int):
+        if self.connector_scheduler is not None:
+            return self.connector_scheduler.update_state_after_alloc(request)
+
+    def build_connector_meta(
+            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
+        if self.connector_scheduler is not None:
+            return self.connector_scheduler.build_connector_meta(
+                scheduler_output)
+        return KVConnectorMetadata()
+
+    def request_finished(
+            self, request: "Request",
+            block_ids: list[int]) -> tuple[bool, Optional[dict[str, Any]]]:
+        if self.connector_scheduler is not None:
+            self.connector_scheduler.request_finished(request)
+        return True, None
+
+
+class CPUOffloadingConnectorScheduler:
+
+    def __init__(self, vllm_config: VllmConfig):
+        logger.info("init CPUOffloadingConnectorScheduler")
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+        self.use_mla = vllm_config.model_config.use_mla
+        self.num_gpu_computed_tokens: dict[str, int] = {}
+        self.num_cpu_computed_tokens: dict[str, int] = {}
+        self.allocated_req_ids: set[str] = set()
+        self.finished_req_ids: list[str] = []
+        self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
+        self.zmq_rpc_client.call("post_init")
+        if vllm_config.kv_transfer_config is not None:
+            self.swap_in_threshold = vllm_config.kv_transfer_config.get_from_extra_config(
+                "swap_in_threshold", 0)
+        else:
+            self.swap_in_threshold = 0
+        logger.info(f"swap_in_threshold: {self.swap_in_threshold}")
+
+    def get_num_new_matched_tokens(
+            self, ori_request: "Request",
+            num_computed_tokens: int) -> tuple[int, bool]:
+        request = copy.deepcopy(ori_request)
+        request.get_hash_new_full_blocks = None
+        num_cpu_computed_tokens, load_async = self.zmq_rpc_client.call(
+            "get_matched_num_and_touch", request)
+        self.num_gpu_computed_tokens[request.request_id] = num_computed_tokens
+        self.num_cpu_computed_tokens[
+            request.request_id] = num_cpu_computed_tokens
+        if num_cpu_computed_tokens - num_computed_tokens >= self.swap_in_threshold:
+            return num_cpu_computed_tokens - num_computed_tokens, load_async
+        else:
+            return 0, load_async
+
+    def update_state_after_alloc(self, request: "Request"):
+        self.allocated_req_ids.add(request.request_id)
+
+    def build_connector_meta(
+            self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata:
+        num_tokens = {}
+        # process scheduled_new_reqs
+        for req in scheduler_output.scheduled_new_reqs:
+            req_id = req.req_id
+            num_tokens[req_id] = (
+                req.num_computed_tokens +
+                scheduler_output.num_scheduled_tokens[req_id])
+
+        # process scheduled_cached_reqs
+        cached_reqs = scheduler_output.scheduled_cached_reqs
+        for idx, req_id in enumerate(cached_reqs.req_ids):
+            num_tokens[req_id] = (
+                cached_reqs.num_computed_tokens[idx] +
+                scheduler_output.num_scheduled_tokens[req_id])
+
+        unallocated_req_ids = set(self.num_gpu_computed_tokens.keys() -
+                                  self.allocated_req_ids -
+                                  scheduler_output.num_scheduled_tokens.keys())
+        new_cpu_block_ids = self.zmq_rpc_client.call("allocate_slots",
+                                                     num_tokens,
+                                                     unallocated_req_ids)
+        metadata = CPUOffloadingConnectorMetadata(
+            requests={},
+            finished_req_ids=set(self.finished_req_ids),
+        )
+        for req in scheduler_output.scheduled_new_reqs:
+            req_id = req.req_id
+            gpu_block_ids = req.block_ids[0]
+            metadata.requests[req_id] = ReqMeta(
+                gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
+                cpu_block_ids=new_cpu_block_ids.get(req_id, []),
+                num_scheduled_tokens=scheduler_output.
+                num_scheduled_tokens[req_id],
+                num_computed_tokens=req.num_computed_tokens,
+                num_gpu_computed_tokens=self.num_gpu_computed_tokens[req_id],
+                num_cpu_computed_tokens=self.num_cpu_computed_tokens[req_id])
+
+        for idx, req_id in enumerate(cached_reqs.req_ids):
+            gpu_block_ids = cached_reqs.new_block_ids[idx]
+            metadata.requests[req_id] = ReqMeta(
+                gpu_block_ids=[] if gpu_block_ids is None else gpu_block_ids,
+                cpu_block_ids=new_cpu_block_ids.get(req_id, []),
+                num_scheduled_tokens=scheduler_output.
+                num_scheduled_tokens[req_id],
+                num_computed_tokens=cached_reqs.num_computed_tokens[idx],
+                num_gpu_computed_tokens=cached_reqs.num_computed_tokens[idx],
+                num_cpu_computed_tokens=cached_reqs.num_computed_tokens[idx])
+        self.num_gpu_computed_tokens.clear()
+        self.num_cpu_computed_tokens.clear()
+        self.allocated_req_ids.clear()
+        self.finished_req_ids.clear()
+        return metadata
+
+    def request_finished(self, ori_request: "Request"):
+        request = copy.deepcopy(ori_request)
+        request.get_hash_new_full_blocks = None
+        self.finished_req_ids.append(request.request_id)
+        # inform the metadata server to record the request and free it
+        # after sending finishes
+        self.zmq_rpc_client.call("record_request_cache_and_free_slots",
+                                 request)
+
+
+class CPUOffloadingConnectorWorker:
+
+    def __init__(self, vllm_config: VllmConfig):
+        logger.info("init CPUOffloadingConnectorWorker")
+        self.vllm_config = vllm_config
+        self.block_size = vllm_config.cache_config.block_size
+        self.pp_rank = get_pp_group().rank_in_group
+        self.tp_group = get_tp_group()
+        self.tp_rank = self.tp_group.rank_in_group
+        self.tp_world_size = self.tp_group.world_size
+        self.use_mla = vllm_config.model_config.use_mla
+
+        self.requests: dict[str, ReqMeta] = {}
+        self.load_stream = torch.npu.Stream()
+        self.save_stream = torch.npu.Stream()
+        self.zmq_rpc_client = MetadataServer.ZMQRPCClient()
+        self.load_block_mapping: list[tuple[int, int]] = []
+        self.save_input_queue: queue.Queue[tuple[str, ReqMeta]] = queue.Queue()
+        self.save_output_queue: queue.Queue[str] = queue.Queue()
+        self.save_thread = threading.Thread(target=self._save_listener)
+        self.save_thread.start()
+        self.done_sending_count: defaultdict[str, int] = defaultdict(int)
+
+        # Start the metadata server to init cpu_kv_cache_manager and handle
+        # RPC requests. All DP ranks share the same metadata server, so only
+        # start it on DP rank 0.
+        if vllm_config.parallel_config.data_parallel_rank == 0 and self.tp_rank == 0 and self.pp_rank == 0:
+            config = VllmConfig()
+            config.cache_config = vllm_config.cache_config
+            config.parallel_config = vllm_config.parallel_config
+            config.kv_transfer_config = vllm_config.kv_transfer_config
+            self.init_metadata_server(config)
+        self._wait_for_metadata_process_start()
+
+    def init_metadata_server(self, vllm_config: VllmConfig):
+        self.metadata_thread = threading.Thread(
+            target=MetadataServerProc.run_metadata_server,
+            args=(vllm_config, ),
+        )
+        self.metadata_thread.daemon = True
+        self.metadata_thread.start()
+
+    def _wait_for_metadata_process_start(self):
+        # TODO: add a dedicated RPC to check whether the metadata server is ready
+        while True:
+            try:
+                if self.zmq_rpc_client.call("ready"):
+                    break
+            except Exception as e:
+                logger.info(f"waiting for metadata server to start, error: {e}")
+                time.sleep(1)
+
+    def bind_connector_metadata(
+            self, connector_metadata: CPUOffloadingConnectorMetadata) -> None:
+        for req_id, req in connector_metadata.requests.items():
+            if req_id in self.requests:
+                self.requests[req_id].update(req)
+                req = self.requests[req_id]
+            else:
+                self.requests[req_id] = req
+            for i in range(req.num_gpu_computed_tokens // self.block_size,
+                           req.num_computed_tokens // self.block_size):
+                self.load_block_mapping.append(
+                    (req.cpu_block_ids[i], req.gpu_block_ids[i]))
+        for req_id in connector_metadata.finished_req_ids:
+            if req_id in self.requests:
+                self.save_input_queue.put((req_id, self.requests[req_id]))
+
+    def clear_connector_metadata(self) -> None:
+        self.load_block_mapping.clear()
+
+    def register_kv_caches(self,
+                           kv_caches: dict[str, Sequence[torch.Tensor]]):
+        self.gpu_kv_caches = kv_caches
+        model_config = self.vllm_config.model_config
+        mla_config: Optional[MLAConfig] = None
+        if model_config.use_mla:
+            mla_config = MLAConfig(
+                model_config.hf_text_config.kv_lora_rank,
+                model_config.hf_text_config.qk_rope_head_dim)
+        self.cpu_kv_caches = list(
+            self.zmq_rpc_client.call(
+                "init_cpu_kv_caches",
+                self.pp_rank,
+                self.tp_rank,
+                get_kv_cache_spec(self.vllm_config),
+                mla_config,
+            ).values())
+
+    def start_load_kv(self) -> None:
+        self.current_layer = 0
+        self.gpu_kv_caches_load_iter = iter(self.gpu_kv_caches.values())
+        self.load_kv_layer(0)
+
+    def wait_for_layer_load(self) -> None:
+        # TODO: Replace with `torch.npu.current_stream().wait_stream(self.load_stream)` after fixing the bug.
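+        # Note: synchronize() blocks the host until every copy queued on
+        # load_stream has finished; the device-side wait_stream() mentioned
+        # in the TODO above would avoid that host stall.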
+        self.load_stream.synchronize()
+        self.current_layer += 1
+        self.load_kv_layer(self.current_layer)
+
+    def load_kv_layer(self, layer: int):
+        if layer == len(self.gpu_kv_caches):
+            return
+        gpu_kv_caches = next(self.gpu_kv_caches_load_iter)
+        cpu_kv_caches = self.cpu_kv_caches[layer]
+        with torch.npu.stream(self.load_stream):
+            for cpu_block_id, gpu_block_id in self.load_block_mapping:
+                for gpu_layer_part, cpu_layer_part in zip(
+                        gpu_kv_caches, cpu_kv_caches):
+                    gpu_layer_part[gpu_block_id].copy_(
+                        cpu_layer_part[cpu_block_id], non_blocking=True)
+
+    def get_finished(self) -> set[str]:
+        done_sending: set[str] = set()
+        while True:
+            try:
+                id = self.save_output_queue.get_nowait()
+            except queue.Empty:
+                break
+            done_sending.add(id)
+        for id in done_sending:
+            del self.requests[id]
+        if self.tp_world_size == 1:
+            return done_sending
+        if self.tp_rank == 0:
+            for req_id in done_sending:
+                self.done_sending_count[req_id] += 1
+            other_ranks_finished_ids: list[str] = []
+            for i in range(1, self.tp_world_size):
+                other_ranks_finished_ids.extend(
+                    self.tp_group.recv_object(src=i))
+            for req_id in other_ranks_finished_ids:
+                self.done_sending_count[req_id] += 1
+            all_done_sending: set[str] = set()
+            for req_id in list(self.done_sending_count.keys()):
+                if self.done_sending_count[req_id] == self.tp_world_size:
+                    del self.done_sending_count[req_id]
+                    all_done_sending.add(req_id)
+            # Release the cpu_kv_cache after sending finishes. To avoid
+            # blocking on the RPC, call it asynchronously from a thread.
+            sending_finished_thread = threading.Thread(
+                target=self._sending_finished, args=(all_done_sending, ))
+            sending_finished_thread.daemon = True
+            sending_finished_thread.start()
+
+            return all_done_sending
+        else:
+            self.tp_group.send_object(done_sending, dst=0)
+            return done_sending
+
+    def _sending_finished(self, all_done_sending):
+        for req_id in all_done_sending:
+            logger.debug(f"call cache_and_free_slots for req_id: {req_id}")
+            self.zmq_rpc_client.call("cache_and_free_slots", req_id)
+
+    def _save_listener(self):
+        save_block_mapping = []
+        while True:
+            req_id, req = self.save_input_queue.get()
+            for i in range(
+                    req.num_cpu_computed_tokens // self.block_size,
+                    min((req.num_computed_tokens + req.num_scheduled_tokens) //
+                        self.block_size, len(req.cpu_block_ids))):
+                save_block_mapping.append(
+                    (req.gpu_block_ids[i], req.cpu_block_ids[i]))
+            with torch.npu.stream(self.save_stream):
+                # MLA: kv_layer is tuple[tensor, tensor], i.e. (rope, nope).
+                # non-MLA: kv_layer is list[tensor], typically [k, v].
+                if self.use_mla:
+                    start, step = self.tp_rank, self.tp_world_size
+                else:
+                    start, step = 0, 1
+                for i in range(start, len(save_block_mapping), step):
+                    gpu_block_id, cpu_block_id = save_block_mapping[i]
+                    for cpu_kv_caches, gpu_kv_caches in zip(
+                            self.cpu_kv_caches, self.gpu_kv_caches.values()):
+                        for cpu_layer_part, gpu_layer_part in zip(
+                                cpu_kv_caches, gpu_kv_caches):
+                            cpu_layer_part[cpu_block_id].copy_(
+                                gpu_layer_part[gpu_block_id],
+                                non_blocking=True)
+            self.save_stream.synchronize()
+            self.save_output_queue.put(req_id)
+            save_block_mapping.clear()
+
+
+# Copied from vllm_ascend/worker/model_runner_v1.py.
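+# Rebuilds the per-layer KV cache specs from the static forward context; the
+# metadata server uses these specs to size the shared CPU cache so that it
+# mirrors the NPU cache layout.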
+def get_kv_cache_spec(vllm_config: VllmConfig) -> dict[str, KVCacheSpec]:
+    forward_ctx = vllm_config.compilation_config.static_forward_context
+    block_size = vllm_config.cache_config.block_size
+    use_mla = vllm_config.model_config.use_mla
+    kv_cache_spec: dict[str, KVCacheSpec] = {}
+    for layer_name, attn_module in forward_ctx.items():
+        if isinstance(attn_module, FusedMoE):
+            continue
+        assert isinstance(attn_module, Attention)
+        if attn_module.attn_type == AttentionType.DECODER:
+            kv_cache_spec[layer_name] = FullAttentionSpec(
+                block_size=block_size,
+                num_kv_heads=attn_module.num_kv_heads,
+                head_size=attn_module.head_size,
+                dtype=attn_module.dtype,
+                use_mla=use_mla)
+        elif attn_module.attn_type in (AttentionType.ENCODER,
+                                       AttentionType.ENCODER_ONLY):
+            continue
+        elif attn_module.attn_type == AttentionType.ENCODER_DECODER:
+            raise NotImplementedError
+        else:
+            raise ValueError(
+                f"Unknown attention type: {attn_module.attn_type}")
+    return kv_cache_spec
diff --git a/vllm_ascend/distributed/cpu_offload_manager/__init__.py b/vllm_ascend/distributed/cpu_offload_manager/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/vllm_ascend/distributed/cpu_offload_manager/cpu_kv_cache_manager.py b/vllm_ascend/distributed/cpu_offload_manager/cpu_kv_cache_manager.py
new file mode 100644
index 0000000..fd68189
--- /dev/null
+++ b/vllm_ascend/distributed/cpu_offload_manager/cpu_kv_cache_manager.py
@@ -0,0 +1,202 @@
+import time
+from collections import defaultdict
+from typing import Optional
+
+from vllm.utils import logger, sha256
+from vllm.v1.core.block_pool import BlockPool
+from vllm.v1.core.kv_cache_utils import (BlockHash, KVCacheBlock,
+                                         PrefixCachingMetrics)
+from vllm.v1.core.single_type_kv_cache_manager import \
+    get_manager_for_kv_cache_spec
+from vllm.v1.kv_cache_interface import KVCacheSpec
+from vllm.v1.metrics.stats import PrefixCacheStats
+from vllm.v1.request import Request
+
+
+class CPUCacheStats:
+
+    def __init__(self, enable_prefix_caching: bool, log_stats: bool = False):
+        self.enable_prefix_caching = enable_prefix_caching
+        self.log_stats = log_stats
+        self.prefix_cache_stats = PrefixCacheStats() if log_stats else None
+        self.cpu_prefix_cache_metrics = PrefixCachingMetrics()
+        self.time_sec = int(time.time())
+
+    def log(self):
+        current_time_sec = int(time.time())
+        # Log the prefix cache hit rate every 10 seconds.
+        if current_time_sec - self.time_sec >= 10:
+            self.time_sec = current_time_sec
+            logger.info("CPU Prefix cache hit rate: %.1f%%",
+                        self.cpu_prefix_cache_metrics.hit_rate * 100)
+
+    def make_prefix_cache_stats(self) -> Optional[PrefixCacheStats]:
+        """Get (and reset) the prefix cache stats.
+
+        Returns:
+            The current prefix caching stats, or None if logging is disabled.
+ """ + if not self.log_stats: + return None + stats = self.prefix_cache_stats + self.prefix_cache_stats = PrefixCacheStats() + return stats + + def update(self, num_tokens, num_computed_tokens): + # Note the function is called by scheduler + if self.log_stats and self.enable_prefix_caching: + assert self.prefix_cache_stats is not None + self.prefix_cache_stats.requests += 1 + self.prefix_cache_stats.queries += num_tokens + self.prefix_cache_stats.hits += num_computed_tokens + + def set_cache_stats(self, num_tokens, num_computed_tokens): + assert self.prefix_cache_stats is not None + self.prefix_cache_stats.hits = num_computed_tokens + self.prefix_cache_stats.queries = num_tokens + self.prefix_cache_stats.requests = 1 + + +class CPUKVCacheManager: + + def __init__( + self, + kv_cache_spec: KVCacheSpec, + num_cpu_blocks: int, + caching_hash_algo: str = "builtin", + use_eagle: bool = False, + enable_kv_cache_events: bool = False, + ) -> None: + self.block_size = kv_cache_spec.block_size + self.num_cpu_blocks = num_cpu_blocks + self.caching_hash_fn = sha256 if caching_hash_algo == "sha256" else hash + self.use_eagle = use_eagle + self.block_pool = BlockPool(self.num_cpu_blocks, True, + enable_kv_cache_events) + self.single_type_manager = get_manager_for_kv_cache_spec( + kv_cache_spec=kv_cache_spec, + block_pool=self.block_pool, + kv_cache_group_id=0, + ) + # Record kv block hashes, avoid redundant computation. + self.req_to_block_hashes: defaultdict[ + str, list[BlockHash]] = defaultdict(list) + # Record blocks touched in get_matched_num_and_touch(). + self.req_to_computed_blocks: defaultdict[ + str, list[KVCacheBlock]] = defaultdict(list) + # Record the request that failed to allocate. + self.req_failed_to_allocate: defaultdict[str, bool] = defaultdict(bool) + self.req_to_num_tokens: defaultdict[str, int] = defaultdict(int) + self.cpu_cache_stats = CPUCacheStats(enable_prefix_caching=True, + log_stats=True) + # Record request that will be free after finish sending + self.req_to_free: defaultdict[str, Request] = defaultdict(Request) + + def get_matched_num_and_touch(self, request: Request) -> tuple[int, bool]: + # When the request requires prompt logprobs, we skip prefix caching. + if (request.sampling_params.prompt_logprobs is not None): + return 0, False + request_id = request.request_id + # The block hashes for the request may already be computed + # if the scheduler has tried to schedule the request before. + block_hashes = self.req_to_block_hashes[request_id] + if not block_hashes: + block_hashes = request.block_hashes + self.req_to_block_hashes[request_id] = block_hashes + max_cache_hit_length = request.num_tokens - 1 + computed_blocks = self.single_type_manager.find_longest_cache_hit( + block_hashes=block_hashes, + max_length=max_cache_hit_length, + kv_cache_group_ids=[0], + block_pool=self.block_pool, + kv_cache_spec=self.single_type_manager.kv_cache_spec, + use_eagle=self.use_eagle, + ) + num_computed_tokens = len(computed_blocks[0]) * self.block_size + self.req_to_computed_blocks[request_id] = computed_blocks[0] + # We should touch these blocks in the concurrent scenarios. 
+        self.block_pool.touch(computed_blocks)
+
+        # Set and log the CPU prefix cache stats.
+        assert self.cpu_cache_stats is not None and self.cpu_cache_stats.prefix_cache_stats is not None
+        self.cpu_cache_stats.set_cache_stats(request.num_tokens,
+                                             num_computed_tokens)
+        self.cpu_cache_stats.cpu_prefix_cache_metrics.observe(
+            self.cpu_cache_stats.prefix_cache_stats)
+        self.cpu_cache_stats.log()
+
+        return num_computed_tokens, False
+
+    def _release_ahead_touch(self, request_id: str):
+        computed_blocks = self.req_to_computed_blocks[request_id]
+        if computed_blocks:
+            self.single_type_manager.block_pool.free_blocks(
+                reversed(computed_blocks))
+        self.req_to_computed_blocks.pop(request_id, None)
+
+    def allocate_slots(self, req_to_num_tokens: dict[str, int],
+                       unallocated_req_ids: set[str]) -> dict[str, list[int]]:
+        for request_id in unallocated_req_ids:
+            self._free_slots(request_id)
+        req_to_new_blocks = {}
+        for request_id, num_tokens in req_to_num_tokens.items():
+            if self.req_failed_to_allocate[request_id]:
+                continue
+            new_computed_blocks = self.req_to_computed_blocks[request_id]
+            num_blocks_to_allocate = (
+                self.single_type_manager.get_num_blocks_to_allocate(
+                    request_id=request_id,
+                    num_tokens=num_tokens,
+                    new_computed_blocks=new_computed_blocks,
+                ))
+            if num_blocks_to_allocate > self.block_pool.get_num_free_blocks():
+                self._release_ahead_touch(request_id)
+                self.req_failed_to_allocate[request_id] = True
+                continue
+            # Append the new computed blocks to the request blocks until now to
+            # avoid the case where the new blocks cannot be allocated.
+            self.single_type_manager.save_new_computed_blocks(
+                request_id, new_computed_blocks)
+            # Allocate new blocks, but do not cache them yet.
+            new_blocks = self.single_type_manager.allocate_new_blocks(
+                request_id, num_tokens)
+            self.req_to_num_tokens[request_id] = num_tokens
+            # No need to release the ahead-of-time ref here: ownership of the
+            # touched blocks has passed to the request's block list.
+            self.req_to_computed_blocks.pop(request_id, None)
+            req_to_new_blocks[request_id] = [
+                block.block_id for block in new_computed_blocks + new_blocks
+            ]
+        return req_to_new_blocks
+
+    def record_request_cache_and_free_slots(self, request: Request):
+        logger.debug(
+            f"record_request_cache_and_free_slots for request {request.request_id} in cpu_kv_cache_manager"
+        )
+        self.req_to_free[request.request_id] = request
+
+    def cache_and_free_slots(self, request_id: str):
+        logger.debug(
+            f"Cache and free slots for request {request_id} in cpu_kv_cache_manager"
+        )
+        if request_id not in self.req_to_free:
+            logger.error(
+                f"request {request_id} not in req_to_free, maybe a bug!")
+            return
+        request = self.req_to_free[request_id]
+        if not self.req_failed_to_allocate[request_id]:
+            self.single_type_manager.cache_blocks(
+                request,
+                self.req_to_num_tokens[request_id],
+            )
+        self._free_slots(request_id)
+        logger.debug(
+            f"delete request {request_id} in cpu_kv_cache_manager req_to_free")
+        del self.req_to_free[request_id]
+
+    def _free_slots(self, request_id: str):
+        # This function is designed to be reentrant.
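+        # (Every step below tolerates a missing request_id: the release helper
+        # checks before freeing and the pops fall back to a default, so calling
+        # this twice for the same request is safe.)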
+        self._release_ahead_touch(request_id)
+        self.single_type_manager.free(request_id)
+        self.req_to_block_hashes.pop(request_id, None)
+        self.req_to_computed_blocks.pop(request_id, None)
+        self.req_failed_to_allocate.pop(request_id, None)
+        self.req_to_num_tokens.pop(request_id, None)
diff --git a/vllm_ascend/distributed/cpu_offload_manager/metadata.py b/vllm_ascend/distributed/cpu_offload_manager/metadata.py
new file mode 100644
index 0000000..ddfd37c
--- /dev/null
+++ b/vllm_ascend/distributed/cpu_offload_manager/metadata.py
@@ -0,0 +1,269 @@
+import math
+import os
+import pickle
+from dataclasses import dataclass
+from multiprocessing.shared_memory import SharedMemory
+from typing import Any, Callable, Optional
+
+import torch
+import vllm.envs as envs
+import zmq
+from vllm.config import KVTransferConfig, VllmConfig
+from vllm.utils import get_dtype_size, logger, make_zmq_socket
+from vllm.v1.kv_cache_interface import AttentionSpec
+
+from vllm_ascend.distributed.cpu_offload_manager.cpu_kv_cache_manager import \
+    CPUKVCacheManager
+
+
+@dataclass
+class MLAConfig:
+    nope_dim: int
+    rope_dim: int
+
+
+def get_cpu_offload_connector(
+        vllm_config: VllmConfig) -> Optional[KVTransferConfig]:
+    if vllm_config.kv_transfer_config is not None:
+        kv_transfer_config = vllm_config.kv_transfer_config
+        if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
+            return kv_transfer_config
+        elif kv_transfer_config.kv_connector == "MultiConnector":
+            ktcs = kv_transfer_config.kv_connector_extra_config.get(
+                "connectors")
+            for ktc in ktcs:
+                kv_transfer_config = KVTransferConfig(**ktc)
+                if kv_transfer_config.kv_connector == "CPUOffloadingConnector":
+                    return kv_transfer_config
+    return None
+
+
+class MetadataServer:
+    METADATA_SERVER_ADDRESS = f"ipc://{envs.VLLM_RPC_BASE_PATH}/metadata.ipc"
+    DEFAULT_CPU_SWAP_SPACE_GB = 800
+
+    class ZMQRPCClient:
+
+        def __init__(self, identity=f"worker-{os.getpid()}"):
+            logger.info(f"metadata client for worker {identity} started")
+            self.ctx = zmq.Context()  # type: ignore
+            self.socket = make_zmq_socket(
+                self.ctx,
+                MetadataServer.METADATA_SERVER_ADDRESS,
+                zmq.DEALER,  # type: ignore
+                bind=False,
+                identity=identity.encode(),
+                linger=0)
+
+        def call(self, func_name: str, *args, **kwargs) -> Any:
+            request = (func_name, args, kwargs)
+            self.socket.send(b"", zmq.SNDMORE)  # type: ignore
+            self.socket.send(pickle.dumps(request))
+            _ = self.socket.recv()
+            response = pickle.loads(self.socket.recv())
+            result, error = response
+            if error:
+                logger.exception(f"call metadata server error: {error}")
+                raise error
+            if func_name == "init_cpu_kv_caches":
+                (memory_dict, layer_size, layer_dtype, mla_config) = result
+                # shared_memory_dict is recorded on self so it can be closed later
+                self.shared_memory_dict = memory_dict
+                result = {}
+                for key, shm in memory_dict.items():
+                    tensor = torch.frombuffer(
+                        shm.buf, dtype=layer_dtype).reshape(layer_size)
+                    if mla_config is not None:
+                        tensor = tensor.split(
+                            [mla_config.nope_dim, mla_config.rope_dim], dim=-1)
+                    result[key] = tensor
+            return result
+
+        def __del__(self):
+            # will be finalized by the outer process
+            self.socket.close()
+            self.ctx.term()
+            if hasattr(self, 'shared_memory_dict'):
+                for shm in self.shared_memory_dict.values():
+                    shm.close()
+
+    def __init__(self, vllm_config: VllmConfig):
+        self.world_size = vllm_config.parallel_config.world_size
+        self.pipeline_parallel_size = vllm_config.parallel_config.pipeline_parallel_size
+        kv_transfer_config = get_cpu_offload_connector(vllm_config)
+        assert kv_transfer_config is not None
+        available_memory_gb = kv_transfer_config.get_from_extra_config(
+            "cpu_swap_space_gb", MetadataServer.DEFAULT_CPU_SWAP_SPACE_GB)
+        self.available_memory = available_memory_gb * 1024 * 1024 * 1024
+        logger.info(f"cpu swap space: {self.available_memory} bytes")
+        self.ctx = zmq.Context()  # type: ignore
+        self.socket = make_zmq_socket(
+            self.ctx,
+            MetadataServer.METADATA_SERVER_ADDRESS,
+            zmq.ROUTER,  # type: ignore
+            bind=True,
+            linger=0)
+        self.functions: dict[str, Callable] = {
+            "init_cpu_kv_caches": self.init_cpu_kv_caches,
+            "post_init": self.post_init,
+            "ready": self.ready,
+        }
+        self.shared_memory = {}  # type: ignore
+        self.num_cpu_blocks = -1
+
+    @staticmethod
+    def _safe_create_shared_memory(name: str, size: int) -> SharedMemory:
+        try:
+            existing_shm = SharedMemory(name=name, create=False)
+            existing_shm.close()
+            existing_shm.unlink()
+        except FileNotFoundError:
+            pass
+        return SharedMemory(name=name, create=True, size=size)
+
+    def ready(self):
+        return True
+
+    def init_cpu_kv_caches(
+        self,
+        pp_rank: int,
+        tp_rank: int,
+        kv_cache_specs: dict[str, AttentionSpec],
+        mla_config: Optional[MLAConfig],
+    ) -> tuple[dict[str, SharedMemory], tuple[int, ...], torch.dtype,
+               Optional[MLAConfig]]:
+        logger.info(f"receive pp rank: {pp_rank}, tp rank: {tp_rank}")
+        # follow the assumption that every layer has the same spec
+        layer = next(iter(kv_cache_specs.values()))
+        assert all([
+            layer.page_size_bytes == spec.page_size_bytes
+            for spec in kv_cache_specs.values()
+        ])
+        # MLA shares the same KV cache among different TP ranks
+        if layer.use_mla:
+            tp_rank = 0
+        if (pp_rank, tp_rank) in self.shared_memory:
+            return self.shared_memory[(pp_rank, tp_rank)]
+        available_memory = self.available_memory
+        shared_memory_dict = {}
+        if layer.use_mla:
+            available_memory //= self.pipeline_parallel_size
+            available_memory //= len(kv_cache_specs)
+            num_blocks = available_memory // layer.page_size_bytes
+            layer_size = (num_blocks, layer.block_size, layer.num_kv_heads,
+                          layer.head_size)  # type: ignore
+        else:
+            available_memory //= self.world_size
+            available_memory //= len(kv_cache_specs)
+            num_blocks = available_memory // layer.page_size_bytes
+            layer_size = (2, num_blocks, layer.block_size, layer.num_kv_heads,
+                          layer.head_size)  # type: ignore
+        nbytes = math.prod(layer_size) * get_dtype_size(layer.dtype)
+        for layer_name in kv_cache_specs.keys():
+            # only this format can be shared over ZeroMQ + pickle
+            shared_memory_dict[
+                layer_name] = MetadataServer._safe_create_shared_memory(
+                    f"cpu_kv_cache_{pp_rank}_{tp_rank}_{layer_name}", nbytes)
+        if layer.use_mla:
+            assert mla_config is not None
+            assert layer.head_size == mla_config.rope_dim + mla_config.nope_dim
+            self.shared_memory[(pp_rank,
+                                tp_rank)] = (shared_memory_dict, layer_size,
+                                             layer.dtype, mla_config)
+        else:
+            self.shared_memory[(pp_rank,
+                                tp_rank)] = (shared_memory_dict, layer_size,
+                                             layer.dtype, None)
+        if self.num_cpu_blocks == -1 or num_blocks < self.num_cpu_blocks:
+            self.num_cpu_blocks = num_blocks
+            self.layer = layer
+        return self.shared_memory[(pp_rank, tp_rank)]
+
+    def post_init(self):
+        # different processes in data parallel may call this multiple times
+        if hasattr(self, 'cpu_block_manager'):
+            return
+        # init_cpu_kv_caches() must have been called at least once by now
+        logger.info(f"assign cpu num blocks: {self.num_cpu_blocks}")
+        assert self.num_cpu_blocks >= 0
+        self.cpu_block_manager = CPUKVCacheManager(self.layer,
+                                                   self.num_cpu_blocks)
+        self.functions.update({
+            "get_matched_num_and_touch":
+            self.cpu_block_manager.get_matched_num_and_touch,
+            "allocate_slots":
+            self.cpu_block_manager.allocate_slots,
+            "record_request_cache_and_free_slots":
+            self.cpu_block_manager.record_request_cache_and_free_slots,
+            "cache_and_free_slots":
+            self.cpu_block_manager.cache_and_free_slots,
+        })
+
+    def serve_step(self):
+        client_id = self.socket.recv()
+        _ = self.socket.recv()
+        raw_msg = self.socket.recv()
+        try:
+            func_name, args, kwargs = pickle.loads(raw_msg)
+        except Exception as e:
+            response = (None, Exception(f"Invalid request: {str(e)}"))
+        else:
+            if func_name in self.functions:
+                try:
+                    result = self.functions[func_name](*args, **kwargs)
+                    response = (result, None)  # type: ignore
+                except Exception as e:
+                    logger.exception(f"metadata execute error: {e}")
+                    response = (None, e)  # type: ignore
+            else:
+                response = (None,
+                            NameError(f"Function {func_name} not found"))
+        self.socket.send(client_id, zmq.SNDMORE)  # type: ignore
+        self.socket.send(b"", zmq.SNDMORE)  # type: ignore
+        self.socket.send(pickle.dumps(response))
+
+    def shutdown(self):
+        self.socket.close()
+        self.ctx.term()
+        socket_path = MetadataServer.METADATA_SERVER_ADDRESS.replace(
+            "ipc://", "")
+        if os.path.exists(socket_path):
+            os.remove(socket_path)
+        for cached in self.shared_memory.values():
+            for shm in cached[0].values():
+                shm.close()
+                shm.unlink()
+
+
+class MetadataServerProc:
+
+    @staticmethod
+    def run_metadata_server(vllm_config: VllmConfig):
+        if (not vllm_config.cache_config.enable_prefix_caching
+                or get_cpu_offload_connector(vllm_config) is None):
+            return
+
+        shutdown_requested = False
+
+        def _signal_handler(signum, frame):
+            nonlocal shutdown_requested
+            if not shutdown_requested:
+                shutdown_requested = True
+                raise SystemExit()
+
+        # Either SIGTERM or SIGINT will terminate the worker
+        # signal.signal(signal.SIGTERM, _signal_handler)
+        # signal.signal(signal.SIGINT, _signal_handler)
+        metadata_server: Optional[MetadataServer] = None
+        try:
+            metadata_server = MetadataServer(vllm_config)
+            logger.info("Metadata server started.")
+            while True:
+                metadata_server.serve_step()
+        except SystemExit:
+            logger.info("Metadata server exiting.")
+            raise
+        except Exception as e:
+            logger.exception(f"Metadata server error: {e}.")
+            raise e
+        finally:
+            if metadata_server is not None:
+                metadata_server.shutdown()
diff --git a/vllm_ascend/models/layers/mla.py b/vllm_ascend/models/layers/mla.py
index fa5317c..57c91bd 100644
--- a/vllm_ascend/models/layers/mla.py
+++ b/vllm_ascend/models/layers/mla.py
@@ -156,8 +156,9 @@ def mla_forward(
     else:
         attn_metadata = forward_context.attn_metadata
     kv_cache = self.mla_attn.kv_cache[forward_context.virtual_engine]
-    self.mla_attn.impl.forward(hidden_states, kv_cache, attn_metadata,
-                               need_gather_q_kv, output)
+    self.mla_attn.impl.forward(self.mla_attn.layer_name, hidden_states,
+                               kv_cache, attn_metadata, need_gather_q_kv,
+                               output)
     return