diff --git a/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md b/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md new file mode 100644 index 0000000..b91705a --- /dev/null +++ b/examples/disaggregated_prefill_v1/mooncake_connector_store_deployment_guide.md @@ -0,0 +1,266 @@ +# Mooncacke Store Deployment Guide + +## Environmental Dependencies + +* Software: + * Python >= 3.9, < 3.12 + * CANN >= 8.2.rc1 + * PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724 + * vLLM:main branch + * vLLM-Ascend:main branch + * Mooncake:[AscendTransport/Mooncake at pooling-async-memcpy](https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy)(Currently available branch code, continuously updated.) + Installation and Compilation Guide:https://github.com/AscendTransport/Mooncake/tree/pooling-async-memcpy?tab=readme-ov-file#build-and-use-binaries + +## run mooncake master + +### 1.Configure mooncake.json + +The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path where mooncake.json is located. + +``` +{ + "local_hostname": "xx.xx.xx.xx", + "metadata_server": "P2PHANDSHAKE", + "protocol": "ascend", + "device_name": "", + "master_server_address": "xx.xx.xx.xx:50088", + "global_segment_size": 30000000000 +} +``` + +**local_hostname**: Configured as the IP address of the current master node, +**metadata_server**: Configured as **P2PHANDSHAKE**, +**protocol:** Configured for Ascend to use Mooncake's HCCL communication, +**device_name**: "" +**master_server_address**: Configured with the IP and port of the master service +**global_segment_size**: Expands the kvcache size registered by the PD node to the master + +### 2. Start mooncake_master + +Under the mooncake folder: + +``` +mooncake_master --port 50088 +``` + +## Pooling and Prefill Decode Disaggregate Scenario + +### 1.Run `prefill` Node and `decode` Node + +Using MultiConnector to simultaneously utilize both p2p connectors and pooled connectors. P2P performs kv_transfer, while pooling creates a larger prefix-cache. + +`prefill` Node: + +``` +bash multi_producer.sh +``` + +The content of the multi_producer.sh script: + +``` +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH +export PYTHONPATH=$PYTHONPATH:/xxxxx/vllm +export MOONCAKE_CONFIG_PATH="/xxxxxx/mooncake.json" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 +export ASCEND_TRANSPORT_PRINT=1 +# The upper boundary environment variable for memory swap logging is set to mooncake, where 1 indicates enabled and 0 indicates disabled. +export ASCEND_AGGREGATE_ENABLE=1 +# The upper-level environment variable is the switch for enabling the mooncake aggregation function, where 1 means on and 0 means off. + +python3 -m vllm.entrypoints.openai.api_server \ + --model /xxxxx/Qwen2.5-7B-Instruct \ + --port 8100 \ + --trust-remote-code \ + --enforce-eager \ + --no_enable_prefix_caching \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 10000 \ + --block-size 128 \ + --max-num-batched-tokens 4096 \ + --kv-transfer-config \ + '{ + "kv_connector": "MultiConnector", + "kv_role": "kv_producer", + "kv_connector_extra_config": { + "use_layerwise": false, + "connectors": [ + { + "kv_connector": "MooncakeConnectorV1", + "kv_role": "kv_producer", + "kv_port": "20001", + "kv_connector_extra_config": { + "prefill": { + "dp_size": 1, + "tp_size": 1 + }, + "decode": { + "dp_size": 1, + "tp_size": 1 + } + } + }, + { + "kv_connector": "MooncakeConnectorStoreV1", + "kv_role": "kv_producer", + } + ] + } +}' > p.log 2>&1 +``` + +`decode` Node: + +``` +bash multi_consumer.sh +``` + +The content of multi_consumer.sh: + +``` +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH +export PYTHONPATH=$PYTHONPATH:/xxxxx/vllm +export MOONCAKE_CONFIG_PATH="/xxxxx/mooncake.json" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=4,5,6,7 +export ASCEND_TRANSPORT_PRINT=1 +# The upper boundary environment variable for memory swap logging is set to mooncake, where 1 indicates enabled and 0 indicates disabled. +export ASCEND_AGGREGATE_ENABLE=1 +# The upper-level environment variable is the switch for enabling the mooncake aggregation function, where 1 means on and 0 means off. + +python3 -m vllm.entrypoints.openai.api_server \ + --model /xxxxx/Qwen2.5-7B-Instruct \ + --port 8200 \ + --trust-remote-code \ + --enforce-eager \ + --no_enable_prefix_caching \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 10000 \ + --block-size 128 \ + --max-num-batched-tokens 4096 \ + --kv-transfer-config \ + '{ + "kv_connector": "MultiConnector", + "kv_role": "kv_consumer", + "kv_connector_extra_config": { + "use_layerwise": false, + "connectors": [ + { + "kv_connector": "MooncakeConnectorV1", + "kv_role": "kv_consumer", + "kv_port": "20002", + "kv_connector_extra_config": { + "prefill": { + "dp_size": 1, + "tp_size": 1 + }, + "decode": { + "dp_size": 1, + "tp_size": 1 + } + } + }, + { + "kv_connector": "MooncakeConnectorStoreV1", + "kv_role": "kv_consumer", + } + ] + } + }' > d.log 2>&1 +``` + +### 2、Start proxy_server. + +``` +bash proxy.sh +``` + +proxy.sh content: +Change localhost to your actual IP address. + +``` +python vllm-ascend/examples/disaggregated_prefill_v1/load_balance_proxy_server_example.py \ + --host localhost\ + --prefiller-hosts localhost \ + --prefiller-ports 8100 \ + --decoder-hosts localhost\ + --decoder-ports 8200 \ +``` + +### 3. Run Inference + +Configure the localhost, port, and model weight path in the command to your own settings. + +Short question: + +``` +curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Hello. I have a question. The president of the United States is", "max_tokens": 200, "temperature":0.0 }' +``` + +Long question: + +``` +curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?", "max_tokens": 256, "temperature":0.0 }' +``` + +## Pooling and Mixed Deployment Scenario + +### 1、Run Mixed Department Script + +The mixed script is essentially a pure pooling scenario for the P node. + +``` +bash mixed_department.sh +``` + +Content of mixed_department.sh: + +``` +export LD_LIBRARY_PATH=/usr/local/Ascend/ascend-toolkit/latest/python/site-packages:$LD_LIBRARY_PATH +export PYTHONPATH=$PYTHONPATH:/xxxxx/vllm +export MOONCAKE_CONFIG_PATH="/xxxxxx/mooncake.json" +export VLLM_USE_V1=1 +export ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 +export ASCEND_TRANSPORT_PRINT=1 +# The upper boundary environment variable for memory swap logging is set to mooncake, where 1 indicates enabled and 0 indicates disabled. +export ASCEND_AGGREGATE_ENABLE=1 +# The upper-level environment variable is the switch for enabling the mooncake aggregation function, where 1 means on and 0 means off. + +python3 -m vllm.entrypoints.openai.api_server \ + --model /xxxxx/Qwen2.5-7B-Instruct \ + --port 8100 \ + --trust-remote-code \ + --enforce-eager \ + --no_enable_prefix_caching \ + --tensor-parallel-size 1 \ + --data-parallel-size 1 \ + --max-model-len 10000 \ + --block-size 128 \ + --max-num-batched-tokens 4096 \ + --kv-transfer-config \ + '{ + "kv_connector": "MooncakeConnectorStoreV1", + "kv_role": "kv_producer", + "kv_connector_extra_config": { + "use_layerwise": false + } +}' > mix.log 2>&1 +``` + +### 2. Run Inference + +Configure the localhost, port, and model weight path in the command to your own settings. The requests sent will only go to the port where the mixed deployment script is located, and there is no need to start a separate proxy. + +Short question: + +``` +curl -s http://localhost:8100/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Hello. I have a question. The president of the United States is", "max_tokens": 200, "temperature":0.0 }' +``` + +Long question: + +``` +curl -s http://localhost:8100/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?", "max_tokens": 256, "temperature":0.0 }' +``` \ No newline at end of file diff --git a/vllm_ascend/distributed/__init__.py b/vllm_ascend/distributed/__init__.py index 458b814..26ddd8f 100644 --- a/vllm_ascend/distributed/__init__.py +++ b/vllm_ascend/distributed/__init__.py @@ -26,3 +26,8 @@ KVConnectorFactory.register_connector( KVConnectorFactory.register_connector( "MooncakeConnectorV1", "vllm_ascend.distributed.mooncake_connector", "MooncakeConnector") + +KVConnectorFactory.register_connector( + "MooncakeConnectorStoreV1", + "vllm_ascend.distributed.mooncake.mooncake_store_connector_v1", + "MooncakeConnectorV1") diff --git a/vllm_ascend/distributed/mooncake/__init__.py b/vllm_ascend/distributed/mooncake/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/vllm_ascend/distributed/mooncake/config_data.py b/vllm_ascend/distributed/mooncake/config_data.py new file mode 100644 index 0000000..abb3c9e --- /dev/null +++ b/vllm_ascend/distributed/mooncake/config_data.py @@ -0,0 +1,447 @@ +import array +import hashlib +import json +import os +from dataclasses import dataclass +from typing import Iterable, List, Optional, Tuple, Union + +import torch +from vllm.distributed.kv_transfer.kv_connector.v1.base import \ + KVConnectorMetadata +from vllm.utils import cdiv, logger +from vllm.v1.core.sched.output import NewRequestData + + +@dataclass +class MooncakeEngineMetadata: + """name of the LLM model""" + + model_name: str + """ world size when running under a distributed setting """ + world_size: int + """ worker id when running under a distributed setting """ + worker_id: int + """ the format of kv tensors """ + kv_dtype: torch.dtype + """ the shape of kv tensors """ + """ (num_layer, 2, metadata.block_size, num_kv_head, head_size) """ + kv_shape: tuple[int, int, int, int, int] + block_size: int = 128 + """ whether use MLA""" + use_mla: bool = False + + +@dataclass(order=True) +class MooncakeEngineKey: + model_name: str + world_size: int + worker_id: int + chunk_hash: str + + def __hash__(self): + return hash(( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + )) + + def to_string(self): + return (f"{self.model_name}@{self.world_size}" + f"@{self.worker_id}@{self.chunk_hash}") + + def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]: + """Split the key into multiple keys for each layer""" + keys = [] + for layer_id in range(num_layers): + keys.append( + LayerMooncakeEngineKey( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + layer_id, + )) + return keys + + def to_dict(self): + # Note(Kuntai): this is used for serializing CacheEngineKey via msgpack. + return { + "__type__": "CacheEngineKey", + "model_name": self.model_name, + "world_size": self.world_size, + "worker_id": self.worker_id, + "chunk_hash": self.chunk_hash, + } + + @staticmethod + def from_dict(d): + return MooncakeEngineKey( + model_name=d["model_name"], + world_size=d["world_size"], + worker_id=d["worker_id"], + chunk_hash=d["chunk_hash"], + ) + + +@dataclass(order=True) +class LayerMooncakeEngineKey(MooncakeEngineKey): + """A key for the layer cache engine""" + + layer_id: int + + def __hash__(self): + return hash(( + self.model_name, + self.world_size, + self.worker_id, + self.chunk_hash, + self.layer_id, + )) + + def to_string(self): + return (f"{self.model_name}@{self.world_size}" + f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}") + + +class ChunkedTokenDatabase(): + + def __init__( + self, + metadata: MooncakeEngineMetadata, + ): + self.metadata = metadata + + def _make_key_by_hash(self, + chunk_hash: str, + layer_id: Optional[int] = None): + assert self.metadata is not None + return MooncakeEngineKey( + self.metadata.model_name, + self.metadata.world_size, + self.metadata.worker_id, + chunk_hash, + ) + + def _hash( + self, + tokens: Union[torch.Tensor, List[int]], + prefix_hash: str, + ) -> str: + # TODO: change it to a more efficient hash function + if isinstance(tokens, torch.Tensor): + tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes() + elif isinstance(tokens, list): + tokens_bytes = array.array("I", tokens).tobytes() + return hashlib.sha256(prefix_hash.encode("ascii") + + tokens_bytes).hexdigest() + + def _chunk_tokens( + self, + tokens: Union[torch.Tensor, List[int]], + ) -> Iterable[Union[torch.Tensor, List[int]]]: + """ + Chunk the tokens into chunks of size self.metadata.block_size. + + :param tokens: the input tokens, with shape [seq_len] + device: the target device after chunking + + :return: a generator of chunks of tokens, each with + shape [metadata.block_size] + """ + for i in range(0, len(tokens), self.metadata.block_size): + yield tokens[i:i + self.metadata.block_size] + + def _prefix_hash( + self, + token_chunks: Iterable[Union[torch.Tensor, List[int]]], + ) -> Iterable[str]: + prefix_hash = '' + for token_chunk in token_chunks: + prefix_hash = self._hash(token_chunk, prefix_hash) + yield prefix_hash + + def process_tokens( + self, + tokens: Union[torch.Tensor, List[int]], + mask: Optional[torch.Tensor] = None, + ) -> Iterable[Tuple[int, int, MooncakeEngineKey]]: + """Process the tokens and return the corresponding cache engine keys. + + :param Union[torch.Tensor, List[int]] tokens: The tokens to process. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched, + and the Falses will ALWAYS be at the PREFIX of the tensor. + + :param bool make_key: Whether to make the cache engine key or not. + If False, the hash value will be returned instead. + + :returns: A iterable of tuples with three elements. The first element + is the start index of the tokens for the key. The second element + is the end index of the tokens for the key. The third element is + the cache engine key (or hash) for the tokens. + + :raises: ValueError if the number of Falses in the mask is not a + multiple of the chunk size. + """ + if mask is not None: + num_falses = mask.numel() - mask.long().sum().item() + else: + num_falses = 0 + + if num_falses % self.metadata.block_size != 0: + raise ValueError( + "The number of Falses in the mask is not a multiple of the chunk size." + ) + total_len = len(tokens) + + token_chunks = self._chunk_tokens(tokens) + prefix_hashes = self._prefix_hash(token_chunks) + + start_idx = 0 + for chunk_id, hash_val in enumerate(prefix_hashes): + start_idx = chunk_id * self.metadata.block_size + end_idx = min(start_idx + self.metadata.block_size, total_len) + if start_idx < num_falses: + continue + else: + yield start_idx, end_idx, self._make_key_by_hash(hash_val) + + +@dataclass +class LoadSpec: + # Number of tokens cached in vLLM + vllm_cached_tokens: int + # Number of tokens that are cached in mooncake + mooncake_cached_tokens: int + # Whether the scheduler allow us to load the tokens + can_load: bool + + +@dataclass +class SaveSpec: + # Skip already saved tokens + skip_leading_tokens: int + # Whether the scheduler allow us to save the tokens + can_save: bool + + +@dataclass +class RequestTracker: + # Request id + req_id: str + + # The token ids that has been scheduled so far + token_ids: list[int] + + # The block ids that has been allocated so far + # NOTE: allocated blocks could be more than the number of tokens + # FIXME: need to check whether the block ids will be changed after + # preemption + allocated_block_ids: list[int] + + # The number of tokens that has been savd + num_saved_tokens: int = 0 + + @staticmethod + def from_new_request( + new_request: "NewRequestData", + num_tokens_to_compute: int, + ) -> "RequestTracker": + """Create the request tracker from a new request. + + Args: + new_request (NewRequestData): the new request data. + num_tokens_to_compute (int): the number of tokens that will + be 'computed', including the `num_computed_tokens` (vLLM's + local cache hit) and new tokens that will be scheduled. + + """ + # vLLM 0.9.0 update: request.block_ids changed from list[int] to + # list[list[int]] + # Need to check the type of request.block_ids + + unfolded_block_ids = [] + + if not isinstance(new_request.block_ids[0], list): + unfolded_block_ids = new_request.block_ids.copy() + else: + unfolded_block_ids = new_request.block_ids[0].copy() + + return RequestTracker( + req_id=new_request.req_id, + token_ids=new_request.prompt_token_ids[:num_tokens_to_compute]. + copy(), + allocated_block_ids=unfolded_block_ids, + num_saved_tokens=0, + ) + + def update( + self, + new_token_ids: list[int], + new_block_ids: Union[tuple[list[int], ...], list[int]], + ) -> None: + """Update the request tracker when a running request is + scheduled again + """ + + self.token_ids.extend(new_token_ids) + + if len(new_block_ids) == 0: + new_block_ids = [] + elif isinstance(new_block_ids, tuple): + new_block_ids = new_block_ids[0] + elif isinstance(new_block_ids, list): + pass + else: + raise ValueError( + f"Unsupported new_block_ids type {type(new_block_ids)}") + self.allocated_block_ids.extend(new_block_ids) + + +@dataclass +class ReqMeta: + # Request id + req_id: str + # Request tokens + token_ids: torch.Tensor + + block_ids: list[int] + # # Slot mapping if exchange for block_id + # slot_mapping: torch.Tensor + # Skip save or not + save_spec: Optional[SaveSpec] = None + # load_spec + load_spec: Optional[LoadSpec] = None + + is_last_chunk: Optional[bool] = None + + @staticmethod + def from_request_tracker( + tracker: RequestTracker, + block_size: int, + load_spec: Optional[LoadSpec] = None, + skip_save: Optional[bool] = False, + is_last_chunk: Optional[bool] = None, + discard_partial_chunks: bool = True, + ) -> Optional["ReqMeta"]: + """Create the request metadata from a request tracker. + + Args: + tracker (RequestTracker): the request tracker. + block_size (int): the block size in vLLM. + load_spec (Optional[LoadSpec]): the load spec for KV cache loading. + skip_save (bool): whether to skip the save operation. + discard_partial_chunks (bool): whether to discard partial chunks. + + Returns: + the request metadata if we need to perform load/save + operations, None otherwise. + """ + input_token_ids = tracker.token_ids + input_token_len = len(input_token_ids) + + # For save operation: do not save if the following condition is met + # 1. has already been saved before (num_saved_tokens > 0) + # 2. number of unsaved tokens is not reached the chunk boundary + skip_leading_tokens = tracker.num_saved_tokens + chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) * + block_size if discard_partial_chunks else 0) + # Calculate number of tokens to save based on discard_partial_chunks + # setting + num_tokens_to_save = ((input_token_len // block_size * block_size) + if discard_partial_chunks else input_token_len) + + skip_save = skip_save or num_tokens_to_save < chunk_boundary + if skip_save and load_spec is None: + return None + + # If we need to save, update the number of saved tokens + if not skip_save: + tracker.num_saved_tokens = num_tokens_to_save + save_spec = SaveSpec(skip_leading_tokens, not skip_save) + + # Calculate the token ids and slot mappings for load and save + # OPTIMIZATION: pre-allocate the buffer for token ids and block ids + token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save] + + # # For load operation: check whether the request is scheduled to load + if load_spec is not None and load_spec.can_load: + logger.debug( + "Scheduled to load %d tokens for request %s", + load_spec.mooncake_cached_tokens, + tracker.req_id, + ) + else: + # Do not load if not in `can_load` state + load_spec = None + logger.debug( + f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}" + ) + return ReqMeta( + req_id=tracker.req_id, + token_ids=token_ids, + block_ids=tracker.allocated_block_ids, + save_spec=save_spec, + load_spec=load_spec, + is_last_chunk=is_last_chunk, + ) + + +class MooncakeConnectorMetadata(KVConnectorMetadata): + + def __init__(self, unfinished_request_ids): + self.requests = [] + self.unfinished_request_ids = unfinished_request_ids + + def add_request(self, req_meta: ReqMeta) -> None: + """Add a request to the metadata. + + Args: + req_meta (ReqMeta): the request metadata. + """ + self.requests.append(req_meta) + + +@dataclass +class LasyerMultiBlockReqMeta: + req_id: str + keys: List[LayerMooncakeEngineKey] + starts: List[int] + ends: list[int] + block_ids: list[int] + layer_id: int + + +@dataclass +class MooncakeStoreConfig: + local_hostname: str + metadata_server: str + global_segment_size: int + local_buffer_size: int + protocol: str + device_name: str + master_server_address: str + + @staticmethod + def from_file(file_path: str) -> "MooncakeStoreConfig": + with open(file_path) as file: + config = json.load(file) + return MooncakeStoreConfig( + local_hostname=config.get("local_hostname"), + metadata_server=config.get("metadata_server"), + global_segment_size=config.get("global_segment_size", 3355443200), + local_buffer_size=config.get("local_buffer_size", 1073741824), + protocol=config.get("protocol", "tcp"), + device_name=config.get("device_name", ""), + master_server_address=config.get("master_server_address")) + + @staticmethod + def load_from_env() -> "MooncakeStoreConfig": + config_path = os.getenv("MOONCAKE_CONFIG_PATH") + if not config_path: + raise ValueError( + "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.") + return MooncakeStoreConfig.from_file(config_path) \ No newline at end of file diff --git a/vllm_ascend/distributed/mooncake/kv_transfer.py b/vllm_ascend/distributed/mooncake/kv_transfer.py new file mode 100644 index 0000000..dee5101 --- /dev/null +++ b/vllm_ascend/distributed/mooncake/kv_transfer.py @@ -0,0 +1,251 @@ +import queue +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Optional + +import torch +from vllm.utils import logger + +from vllm_ascend.distributed.mooncake.config_data import ( + ChunkedTokenDatabase, LasyerMultiBlockReqMeta) +from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore + + +class KVTransferThread(threading.Thread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event, name: str): + super().__init__(daemon=True, name=name) + self.tp_rank = tp_rank + self.tp_size = tp_size + self.m_store = m_store + self.ready_event = ready_event + self.kv_caches_base_addr = local_kv_caches_base_addr + self.block_len = block_len + self.token_database = token_database + self.block_size = block_size + self.done_task_lock = threading.Lock() + # TODO(jianzs): find a better way to detect MLA. + self.use_mla = len(block_len) == 2 + + self.request_queue: queue.Queue[Any] = queue.Queue() + # TODO(jianzs): make this configurable + self.executor = ThreadPoolExecutor(max_workers=32) + self.finished_requests: set[str] = set() + + def prepare_value(self, start: int, end: int, block_ids: list[int]): + addr_list = [] + size_list = [] + block_id = block_ids[start // self.block_size] + for index, base_addr in enumerate(self.kv_caches_base_addr): + block_len = (self.block_len[index % 2] + if self.use_mla else self.block_len[0]) + + addr = base_addr + block_id * block_len + length = int(block_len / self.block_size * (end - start)) + addr_list.append(addr) + size_list.append(length) + return addr_list, size_list, block_id + + def prepare_value_layer(self, start: int, end: int, block_ids: list[int], + layer_id: int): + block_id = block_ids[start // self.block_size] + if self.use_mla: + addr_k = self.kv_caches_base_addr[layer_id * + 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + + 1] + block_id * self.block_len[1] + length_k = int(self.block_len[0] / self.block_size * (end - start)) + length_v = int(self.block_len[1] / self.block_size * (end - start)) + size_list = [length_k, length_v] + else: + addr_k = self.kv_caches_base_addr[layer_id * + 2] + block_id * self.block_len[0] + addr_v = self.kv_caches_base_addr[layer_id * 2 + + 1] + block_id * self.block_len[0] + length = int(self.block_len[0] / self.block_size * (end - start)) + size_list = [length, length] + addr_list = [addr_k, addr_v] + return addr_list, size_list + + def add_request( + self, + req_id: str, + tokens: torch.Tensor, + block_ids: list[int], + mask: Optional[torch.Tensor] = None, + is_last_chunk: Optional[bool] = None, + ) -> torch.Tensor: + req = ({ + "req_id": req_id, + "tokens": tokens, + "block_ids": block_ids, + "mask": mask, + "is_last_chunk": is_last_chunk, + }) + self.request_queue.put(req) + + def get_and_clear_finished_requests(self) -> set[str]: + """ + Get and clear the requests that have been completed. + Returns: + A set of request IDs that have been completed. + """ + with self.done_task_lock: + finished_requests = self.finished_requests.copy() + self.finished_requests.clear() + return finished_requests + + def set_finished_request(self, req_id): + with self.done_task_lock: + self.finished_requests.add(req_id) + + def run(self): + """Run the thread to handle KV cache transfer requests.""" + self.ready_event.set() + while True: + try: + request_data = self.request_queue.get() + if request_data is None: + logger.warning("Received a None request!") + self.request_queue.task_done() + continue + self._handle_request(request_data) + except Exception as e: + logger.error(f"Error in KVCacheTransferThread: {e}") + + def _handle_request(self, req_meta: dict[str, Any]): + pass + + +class KVCacheStoreSendingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheSendingThread") + + def _handle_request(self, req_meta: dict[str, Any]): + tokens = req_meta["tokens"] + mask = req_meta["mask"] + block_ids = req_meta["block_ids"] + req_id = req_meta["req_id"] + is_last_chunk = req_meta["is_last_chunk"] + torch.npu.current_stream().synchronize() + for start, end, key in self.token_database.process_tokens( + tokens, mask): + addr, size, _ = self.prepare_value(start, end, block_ids) + self.m_store.put(key, addr, size) + if is_last_chunk: + self.set_finished_request(req_id) + self.request_queue.task_done() + + +class KVCacheStoreRecvingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheStoreRecvingThread") + + def _handle_request(self, req_meta: dict[str, Any]): + tokens = req_meta["tokens"] + mask = req_meta["mask"] + block_ids = req_meta["block_ids"] + req_id = req_meta["req_id"] + for start, end, key in self.token_database.process_tokens( + tokens, mask): + addr, size, _ = self.prepare_value(start, end, block_ids) + self.m_store.get(key, addr, size) + self.set_finished_request(req_id) + self.request_queue.task_done() + + +class KVCacheStoreLayerSendingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event, + num_layers: int): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheStoreLayerSendingThread") + self.final_layer_id = num_layers - 1 + + def add_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + self.request_queue.put(req_meta) + + def _handle_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta): + torch.npu.current_stream().synchronize() + for index, key in enumerate(req_meta.keys): + addr, size = self.prepare_value_layer(req_meta.starts[index], + req_meta.ends[index], + req_meta.block_ids, + req_meta.layer_id) + self.m_store.put(key, addr, size) + if req_meta.layer_id == self.final_layer_id: + self.set_finished_request(req_meta.req_id) + self.request_queue.task_done() + + +class KVCacheStoreLayerRecvingThread(KVTransferThread): + + def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore, + local_kv_caches_base_addr: list[int], + token_database: ChunkedTokenDatabase, block_len: list[int], + block_size: int, ready_event: threading.Event, + get_event: threading.Event): + super().__init__(tp_rank, + tp_size, + m_store, + local_kv_caches_base_addr, + token_database, + block_len, + block_size, + ready_event, + name="KVCacheStoreLayerRecvingThread") + self.get_event = get_event + + def add_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor: + self.request_queue.put(req_meta) + + def _handle_request( # type: ignore[override] + self, req_meta: LasyerMultiBlockReqMeta): + for index, key in enumerate(req_meta.keys): + addr, size = self.prepare_value_layer(req_meta.starts[index], + req_meta.ends[index], + req_meta.block_ids, + req_meta.layer_id) + self.m_store.get(key, addr, size) + self.request_queue.task_done() + self.get_event.set() diff --git a/vllm_ascend/distributed/mooncake/mooncake_engine.py b/vllm_ascend/distributed/mooncake/mooncake_engine.py new file mode 100644 index 0000000..53c2724 --- /dev/null +++ b/vllm_ascend/distributed/mooncake/mooncake_engine.py @@ -0,0 +1,489 @@ +# Standard +import math +import threading +import time +from typing import Generator, List, Optional, Union + +# Third Party +import torch +from vllm.config import VllmConfig +from vllm.utils import get_kv_cache_torch_dtype, logger + +from vllm_ascend.distributed.mooncake.config_data import ( + ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata, + MooncakeEngineMetadata) +from vllm_ascend.distributed.mooncake.kv_transfer import ( + KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread, + KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread) +from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore + + +class MooncakeEngine: + #The main class for the cache engine. + + def __init__( + self, + vllm_config: VllmConfig, + use_layerwize: bool, + ): + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + self.use_mla = False + if (hasattr(model_config, "use_mla") + and isinstance(model_config.use_mla, bool) + and model_config.use_mla): + self.use_mla = True + self.use_layerwise = use_layerwize + self.tp_rank = parallel_config.rank + self.tp_size = parallel_config.tensor_parallel_size + self.kv_role = vllm_config.kv_transfer_config.kv_role + self.block_size = vllm_config.cache_config.block_size + self.current_layer = 0 + # self.use_mla = first_kv_cache_tuple[0].size( + # -1) != first_kv_cache_tuple[1].size(-1) + self.num_layers = model_config.get_num_layers(parallel_config) + self.block_size = vllm_config.cache_config.block_size + num_kv_head = model_config.get_num_kv_heads(parallel_config) + head_size = model_config.get_head_size() + kv_dtype = get_kv_cache_torch_dtype( + vllm_config.cache_config.cache_dtype, model_config.dtype) + self.hidden_dim_size = num_kv_head * head_size + if self.use_mla: + kv_shape = (self.num_layers, 1, self.block_size, 1, head_size) + else: + kv_shape = (self.num_layers, 2, self.block_size, num_kv_head, + head_size) + self.metadata = MooncakeEngineMetadata( + model_config.model, + parallel_config.world_size, + parallel_config.rank, + kv_dtype, + kv_shape, + self.block_size, + self.use_mla, + ) + + self.token_database = ChunkedTokenDatabase(self.metadata) + + self.m_store = Mooncakestore(parallel_config) + + self.kv_send_thread: Optional[KVTransferThread] = None + self.kv_recv_thread: Optional[KVTransferThread] = None + + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + _, first_kv_cache_tuple = next(iter(kv_caches.items())) + first_kv_cache = first_kv_cache_tuple[0] + + # TODO(tms): Find a more robust way to detect and handle MLA + if self.use_mla: + # MLA case.[num_block, block_size, 1, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + block_rank = 3 # [block_size, latent_dim] + block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:] + block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:] + self.block_len = [ + first_kv_cache[0].element_size() * math.prod(block_shape_norm), + first_kv_cache[1].element_size() * math.prod(block_shape_pe) + ] + logger.info( + "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s", + self.num_blocks, block_shape_norm, block_shape_pe) + else: + # [num_block, block_size, num_head, hidden_dim] + self.num_blocks = first_kv_cache.shape[0] + kv_elem_size = first_kv_cache.element_size() + block_rank = 3 # [block_size, kv_heads, head_dim] + block_shape = first_kv_cache.shape[-block_rank:] + self.block_len = [kv_elem_size * math.prod(block_shape)] + logger.info("num_blocks: %s, block_shape: %s", self.num_blocks, + block_shape) + + logger.info("Registering KV_Caches. use_mla: %s, shape %s", + self.use_mla, first_kv_cache.shape) + + self.kv_caches = kv_caches + self.m_store.set_kv_caches(kv_caches.values()) + self.kv_caches_base_addr = [] + for cache_or_caches in kv_caches.values(): + # Normalize to always be a list of caches + if self.use_mla: + for i, cache in enumerate(cache_or_caches, 0): + base_addr = cache.data_ptr() + self.kv_caches_base_addr.append(base_addr) + else: + cache_list = [cache_or_caches + ] if self.use_mla else cache_or_caches + for cache in cache_list: + base_addr = cache.data_ptr() + self.kv_caches_base_addr.append(base_addr) + + if self.use_layerwise: + self.get_event = threading.Event() + if self.kv_role == 'kv_producer': + ready_event_sending = threading.Event() + self.kv_send_thread = KVCacheStoreLayerSendingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, + self.block_len, self.block_size, ready_event_sending, + self.num_layers) + self.kv_send_thread.start() + ready_event = threading.Event() + self.kv_recv_thread = KVCacheStoreLayerRecvingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, self.block_len, + self.block_size, ready_event, self.get_event) + self.kv_recv_thread.start() + ready_event.wait() + else: + if self.kv_role == 'kv_producer': + ready_event_sending = threading.Event() + self.kv_send_thread = KVCacheStoreSendingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, + self.block_len, self.block_size, ready_event_sending) + self.kv_send_thread.start() + ready_event = threading.Event() + self.kv_recv_thread = KVCacheStoreRecvingThread( + self.tp_rank, self.tp_size, self.m_store, + self.kv_caches_base_addr, self.token_database, self.block_len, + self.block_size, ready_event) + self.kv_recv_thread.start() + ready_event.wait() + + def start_load_kv(self, metadata: MooncakeConnectorMetadata): + self.current_layer = 0 + self.layerwise_retrievers = [] + for request in metadata.requests: + load_spec = request.load_spec + if load_spec is None or not load_spec.can_load: #load =0 + continue + tokens = request.token_ids + req_id = request.req_id + if (load_spec.mooncake_cached_tokens % self.block_size + != 0) and (load_spec.mooncake_cached_tokens + == tokens.shape[0] - 1): + tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1] + else: + tokens = tokens[:request.load_spec.mooncake_cached_tokens] + masked_token_count = (request.load_spec.vllm_cached_tokens // + self.block_size * self.block_size) + token_mask = torch.ones_like(tokens, dtype=torch.bool) + token_mask[:masked_token_count] = False + if self.use_layerwise: + layerwise_retriever = self.retrieve_layer( + req_id, + tokens, + request.block_ids, + token_mask, + ) + next(layerwise_retriever) # first layer load + self.layerwise_retrievers.append(layerwise_retriever) + else: + self.kv_recv_thread.add_request( # type: ignore[union-attr] + req_id, + tokens, + request.block_ids, + token_mask, + ) + + def wait_for_layer_load(self) -> None: + """MooncakeConnector does not do layerwise saving.""" + for layerwise_retriever in self.layerwise_retrievers: + ret_token_mask = next(layerwise_retriever) + if self.current_layer == self.num_layers - 1: + assert ret_token_mask is not None + num_retrieved_tokens = ret_token_mask.sum().item() + logger.info(f"Retrieved {num_retrieved_tokens} tokens") + + def save_kv_layer(self, + connector_metadata: MooncakeConnectorMetadata) -> None: + """MooncakeConnector does not save explicitly.""" + if self.current_layer == 0: + self.layerwise_storers = [] + for request in connector_metadata.requests: + save_spec = request.save_spec + if save_spec is None or not save_spec.can_save: + continue + + token_ids = request.token_ids + req_id = request.req_id + assert isinstance(token_ids, torch.Tensor) + assert token_ids.is_cpu + + # TODO: whether need to remov saveThread + # no lookup, skipmask + skip_leading_tokens = max( + self.lookup(token_ids, self.use_layerwise), + save_spec.skip_leading_tokens, + ) + if skip_leading_tokens == len(token_ids): + if request.is_last_chunk: + self.kv_send_thread.set_finished_request( # type: ignore[union-attr] + req_id) + continue # skip this request + + skip_leading_tokens = (skip_leading_tokens // self.block_size * + self.block_size) + + store_mask = torch.ones_like(token_ids, dtype=torch.bool) + store_mask[:skip_leading_tokens] = False + logger.info( + "Storing KV cache for %d out of %d tokens " + "(skip_leading_tokens=%d) for request %s", + len(token_ids) - skip_leading_tokens, + len(token_ids), + skip_leading_tokens, + request.req_id, + ) + + layerwise_storer = self.store_layer( + req_id, + token_ids, + mask=store_mask, + block_ids=request.block_ids, + ) + self.layerwise_storers.append(layerwise_storer) + for layerwise_storer in self.layerwise_storers: + try: + next(layerwise_storer) + except Exception: + raise + self.current_layer = self.current_layer + 1 + + def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata): + """MooncakeConnector does not save explicitly.""" + for request in connector_metadata.requests: + save_spec = request.save_spec + if save_spec is None or not save_spec.can_save: + continue + + token_ids = request.token_ids + req_id = request.req_id + assert isinstance(token_ids, torch.Tensor) + assert token_ids.is_cpu + + skip_leading_tokens = max( + self.lookup(token_ids, self.use_layerwise), + save_spec.skip_leading_tokens, + ) + if skip_leading_tokens == len(token_ids): + if request.is_last_chunk: + self.kv_send_thread.set_finished_request( # type: ignore[union-attr] + req_id) + continue # skip this request + + skip_leading_tokens = (skip_leading_tokens // self.block_size * + self.block_size) + + store_mask = torch.ones_like(token_ids, dtype=torch.bool) + store_mask[:skip_leading_tokens] = False + + logger.info( + "Storing KV cache for %d out of %d tokens " + "(skip_leading_tokens=%d) for request %s", + len(token_ids) - skip_leading_tokens, + len(token_ids), + skip_leading_tokens, + request.req_id, + ) + + self.kv_send_thread.add_request( # type: ignore[union-attr] + req_id, + token_ids, + request.block_ids, + store_mask, + request.is_last_chunk, + ) + + def retrieve_layer( + self, + req_id: str, + tokens: torch.Tensor, + block_ids: list[int], + mask: Optional[torch.Tensor] = None, + ) -> Generator[Optional[torch.Tensor], None, None]: + """ + Retrieve the KV cache in a layerwise manner. + + :param torch.Tensor tokens: The tokens of the corresponding KV caches. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched. + + :param **kwargs: The additional arguments for the KV transfer which + will be passed into the npu_transfer. + + return: A generator that yields Optional[torch.Tensor]. The tensor will + be the boolean mask indicating which tokens are retrieved and will + only be returned in the last iteration. + """ + + if mask is not None: + num_required_tokens = torch.sum(mask).item() + else: + num_required_tokens = len(tokens) + + ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu") + + starts = [] + ends = [] + keys = [] + first_flag = True + for start, end, key in self.token_database.process_tokens( + tokens, mask): + keys_multi_layer = key.split_layers(self.num_layers) + starts.append(start) + ends.append(end) + keys.append(keys_multi_layer) + ret_mask[start:end] = True + + if keys: + # Transpose the keys into layer major format + keys = [list(row) for row in zip(*keys)] # [num_layer,block_num] + for layer_id, keys_multi_chunk in enumerate(keys): + if not first_flag: + is_finish = self.get_event.wait(timeout=3) #try---cache + if not is_finish: + logger.info("Layerwise get failed") + self.get_event.clear() + req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, + starts, ends, block_ids, + layer_id) + self.kv_recv_thread.add_request( # type: ignore[union-attr, call-arg] + req_meta) # type: ignore[union-attr, call-arg, arg-type] + first_flag = False + yield None + else: + # If no cache are found, we still need to yield to avoid + # `StopIteration` + for layer_id in range(self.num_layers): + yield None + + retrieved_tokens = torch.sum(ret_mask) + logger.debug(f"Retrieved {retrieved_tokens} " + f"out of {num_required_tokens} " + f"out of total {len(tokens)} tokens") + + yield ret_mask + + def store_layer( + self, + req_id: str, + tokens: torch.Tensor, + block_ids: list[int], + mask: Optional[torch.Tensor] = None, + ) -> Generator[None, None, None]: + """ + Store the KV cache in a layerwise manner. + + :param torch.Tensor tokens: The tokens of the corresponding KV caches. + + :param Optional[torch.Tensor] mask: The mask for the tokens. Should + have the same length as tokens. And the mask should ALWAYS be like + FFFFFTTTTTTT, where True means the tokens needs to be matched. + + :param **kwargs: The additional arguments for the storage backend which + will be passed into the gpu_connector. + + return: A generator that yields None. In the first iteration, the + generator allocates the memory objects for all layers and moves + the KV cache of the first layer from GPU to CPU. In the next + iterations, it moves the KV cache of layer i from GPU to the memory + objects (on CPU) and puts the memory objects of layer i-1 to the + storage backends. In the last iteration, it puts the memory objects + of the last layer to the storage backends. + """ + + if mask is not None: + num_stored_tokens = torch.sum(mask).item() + else: + num_stored_tokens = len(tokens) + + starts = [] + ends = [] + keys = [] + for start, end, key in self.token_database.process_tokens( + tokens, mask): + keys_multi_layer = key.split_layers(self.num_layers) + starts.append(start) + ends.append(end) + keys.append(keys_multi_layer) #[block_num,layer_num] + + if keys: + keys = [list(row) for row in zip(*keys)] #[layer_num,block_num] + for layer_id, keys_multi_chunk in enumerate(keys): + req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk, + starts, ends, block_ids, + layer_id) + self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg] + req_meta) # type: ignore[union-attr, call-arg, arg-type] + yield + else: + for layer_id in range(self.num_layers): + yield + logger.debug( + f"Stored {num_stored_tokens} out of total {len(tokens)} tokens") + + def get_finished(self) -> tuple[set[str], set[str]]: + done_sending = ( + self.kv_send_thread. + get_and_clear_finished_requests( # type: ignore[union-attr] + ) if self.kv_role == 'kv_producer' else set()) + done_recving = self.kv_recv_thread.get_and_clear_finished_requests( # type: ignore[union-attr] + ) + + logger.debug( + "Number of completed KV cache send requests: %d, receive " + "requests: %d, tp_rank:%d", len(done_sending), len(done_recving), + self.tp_rank) + return done_sending, done_recving + + def wait_layer_transfer_finish(self): + time.sleep(10) + pass + + def lookup( + self, + tokens: Union[torch.Tensor, List[int]], + use_layerwise: bool, + ) -> int: + """ + Checks the existence of KV cache of the tokens from the cache engine. + + :param tokens: the input tokens, with shape [seq_len] + + :return: An int indicating how many prefix tokens are cached. + """ + end = 0 + + for start, end, key in self.token_database.process_tokens(tokens): + try: + if use_layerwise: + keys = [] + keys_multi_layer = key.split_layers(self.num_layers) + for key in keys_multi_layer: + keys.append(key.to_string()) + # batch is_exists + ress = self.m_store.batch_exists(keys) + res = 1 + for value in ress: + if value != 1: + res = 0 + break + else: + res = self.m_store.exists(key) + if res == 1: + continue + else: + return start + except Exception as e: + logger.warning(f"Remote connection failed in contains: {e}") + return start + + # all tokens where found, return the maximal end + return end + + def close(self) -> None: + """Close the cache engine and free all the resources""" + self.m_store.close() diff --git a/vllm_ascend/distributed/mooncake/mooncake_store.py b/vllm_ascend/distributed/mooncake/mooncake_store.py new file mode 100644 index 0000000..2383749 --- /dev/null +++ b/vllm_ascend/distributed/mooncake/mooncake_store.py @@ -0,0 +1,88 @@ +# Standard +import os + +# Third Party +from vllm.config import ParallelConfig +from vllm.distributed.parallel_state import get_tensor_model_parallel_rank +from vllm.utils import logger + +from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey + +from .config_data import MooncakeStoreConfig + +METADATA_BYTES_LEN = 24 +BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790")) + + +class Mooncakestore(): + + def __init__(self, parallel_config: ParallelConfig): + try: + from mooncake.store import MooncakeDistributedStore # type: ignore + except ImportError as e: + raise ImportError( + "Please install mooncake by following the instructions at " + "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501 + "to run vLLM with MooncakeConnector.") from e + tp_rank = get_tensor_model_parallel_rank() + tp_size = parallel_config.tensor_parallel_size + dp_rank = parallel_config.data_parallel_rank_local + all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None) + if not all_device_ids: + device_ids_list = list( + range(dp_rank * tp_size, (dp_rank + 1) * tp_size)) + else: + device_ids_list = list(map(int, all_device_ids.split(','))) + assert len(device_ids_list) > tp_rank + device_id = device_ids_list[tp_rank] + self.config = MooncakeStoreConfig.load_from_env() + if self.config.protocol == "ascend": + local_hostname = self.config.local_hostname + ":" + str(BASE_PORT + int(device_id)) + \ + ":npu_" + str(device_id) + else: + local_hostname = self.config.local_hostname + self.store = MooncakeDistributedStore() + ret = self.store.setup(local_hostname, self.config.metadata_server, + self.config.global_segment_size, + self.config.local_buffer_size, + self.config.protocol, self.config.device_name, + self.config.master_server_address) + if ret != 0: + msg = "Initialize mooncake failed." + logger.error(msg) + raise RuntimeError(msg) + + def set_kv_caches(self, kvcache): + self.kvcache = list(kvcache) + + def exists(self, key: MooncakeEngineKey) -> bool: + return self.store.is_exist(key.to_string()) == 1 + + def batch_exists(self, keys: list[str]) -> list[bool]: + return self.store.batch_is_exist(keys) + + def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): + expect_res = sum(size) + key_str = key.to_string() + try: + res = self.store.batch_get_into_ascend(key_str, addr, size) + if res[0] != expect_res: + logger.error(f"Failed to get key: [{key_str}] .") + except Exception: + logger.error(f"Failed to get key: [{key_str}] .") + return res + + def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]): + key_str = key.to_string() + try: + ret = self.store.batch_put_from_ascend(key_str, addr, size) + if ret[0] != 0: + logger.error(f"Failed to put key {key_str}.") + except Exception: + logger.error(f"Failed to put key {key_str}.") + + return ret + + def close(self): + self.store.close() + logger.info("Closed the mooncake store connection") \ No newline at end of file diff --git a/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py new file mode 100644 index 0000000..6254e47 --- /dev/null +++ b/vllm_ascend/distributed/mooncake/mooncake_store_connector_v1.py @@ -0,0 +1,484 @@ +import threading +from typing import Any, Optional + +import torch +import vllm.envs as envs +import zmq +from vllm.attention.backends.abstract import AttentionMetadata +from vllm.config import VllmConfig +from vllm.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) +from vllm.forward_context import ForwardContext +from vllm.utils import logger, make_zmq_socket +from vllm.v1.core.kv_cache_manager import KVCacheBlocks +from vllm.v1.core.sched.output import SchedulerOutput +from vllm.v1.request import Request +from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder + +from vllm_ascend.distributed.mooncake.config_data import ( + LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker) +from vllm_ascend.distributed.mooncake.mooncake_engine import MooncakeEngine + + +class MooncakeConnectorV1(KVConnectorBase_V1): + + def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole): + super().__init__(vllm_config=vllm_config, role=role) + self.kv_role = vllm_config.kv_transfer_config.kv_role + + self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get( + "use_layerwise", False) + + self.kv_caches: dict[str, torch.Tensor] = {} + + self._block_size = vllm_config.cache_config.block_size + + self.sended_but_unfinished_reqs: set[str] = set() + + if role == KVConnectorRole.SCHEDULER: + self.connector_scheduler = MooncakeStoreConnectorV1Scheduler( + vllm_config, self.use_layerwise) + else: + self.connector_worker = MooncakeEngine( + vllm_config, + self.use_layerwise, + ) + + assert self.connector_worker is not None + if vllm_config.parallel_config.rank == 0: + self.lookup_server = MooncakeLookupServer( + self.connector_worker, vllm_config, self.use_layerwise) + + ############################################################ + # Scheduler Side Methods + ############################################################ + + def get_num_new_matched_tokens( + self, request: "Request", + num_computed_tokens: int) -> tuple[int, bool]: + assert self.connector_scheduler is not None + return self.connector_scheduler.get_num_new_matched_tokens( + request, num_computed_tokens) + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + assert self.connector_scheduler is not None + return self.connector_scheduler.update_state_after_alloc( + request, blocks, num_external_tokens) + + def build_connector_meta( + self, + scheduler_output: SchedulerOutput, + ) -> KVConnectorMetadata: + assert self.connector_scheduler is not None + return self.connector_scheduler.build_connector_meta(scheduler_output) + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request, block_ids) + + ############################################################ + # Worker Side Methods + ############################################################ + def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): + assert self.connector_worker is not None + self.connector_worker.register_kv_caches(kv_caches) + + def start_load_kv(self, forward_context: "ForwardContext", + **kwargs) -> None: + assert self.connector_worker is not None + assert isinstance(self._get_connector_metadata(), + MooncakeConnectorMetadata) + self.connector_worker.start_load_kv(self._get_connector_metadata()) + + def wait_for_layer_load(self, layer_name: str) -> None: + """MooncakeStoreConnector does not do layerwise saving.""" + if not self.use_layerwise: + return + self.connector_worker.wait_for_layer_load() + + def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor, + attn_metadata: "AttentionMetadata", **kwargs) -> None: + """MooncakeStoreConnector does not save explicitly.""" + if not self.use_layerwise: + return + + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + self.connector_worker.save_kv_layer(self._get_connector_metadata()) + + def wait_for_save(self): + """MooncakeStoreConnector does not save explicitly.""" + if self.kv_role == "kv_consumer": + # Don't do save if the role is kv_consumer + return + + if self.use_layerwise: + self.connector_worker.wait_layer_transfer_finish() + return + + self.connector_worker.wait_for_save(self._get_connector_metadata()) + + def get_finished(self, + finished_req_ids: set[str]) -> tuple[set[str], set[str]]: + """Get the finished recving and sending requests.""" + assert self.connector_worker is not None + meta = self._get_connector_metadata() + done_sending, done_recving = self.connector_worker.get_finished() + sended_and_finished: set[str] = set() + for item in list(self.sended_but_unfinished_reqs): + if item not in meta.unfinished_request_ids: + sended_and_finished.add(item) + self.sended_but_unfinished_reqs.remove(item) + for item in done_sending: + if item in meta.unfinished_request_ids: + self.sended_but_unfinished_reqs.add(item) + else: + sended_and_finished.add(item) + + return sended_and_finished, done_recving + + +def get_zmq_rpc_path_mooncake( + vllm_config: Optional["VllmConfig"] = None, ) -> str: + base_url = envs.VLLM_RPC_BASE_PATH + # Default to 0 if not configured + rpc_port = 0 + if vllm_config is not None: + rpc_port = vllm_config.kv_transfer_config.get_from_extra_config( + "mooncake_rpc_port", 0) + logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port) + return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}" + + +class MooncakeStoreConnectorV1Scheduler: + + def __init__(self, vllm_config: "VllmConfig", use_layerwise): + self.client = MooncakeLookupClient(vllm_config) + self.use_layerwise = use_layerwise + self.kv_role = vllm_config.kv_transfer_config.kv_role + # request_id -> (vllm cached tokes, mooncake cached tokens) + self.load_specs: dict[str, LoadSpec] = {} + self._block_size = vllm_config.cache_config.block_size + # request_id -> full_token_ids + self._request_trackers: dict[str, RequestTracker] = {} + # Whether to discard partial chunks + self._discard_partial_chunks = ( + vllm_config.kv_transfer_config.get_from_extra_config( + "discard_partial_chunks", True)) + self._unfinished_requests: dict[str, tuple[Request, list[int]]] = {} + self._unfinished_request_ids: set[str] = set() + + def get_num_new_matched_tokens( + self, + request: "Request", + num_computed_tokens: int, + ) -> tuple[int, bool]: + """ + Check for external KV cache hit. + + Args: + request (Request): the request object. + num_computed_tokens (int): the number of locally + computed tokens for this request + + Returns: + the number of tokens that can be loaded from the + external KV cache beyond what is already computed. + """ + + if self._discard_partial_chunks: + token_block_end = len(request.prompt_token_ids + ) // self._block_size * self._block_size + token_ids = torch.tensor( + request.prompt_token_ids[:token_block_end]) + else: + token_ids = torch.tensor(request.prompt_token_ids) + + num_external_hit_tokens = self.client.lookup(token_ids) + + if num_external_hit_tokens == request.num_tokens: + num_external_hit_tokens -= 1 + + need_to_allocate = num_external_hit_tokens - num_computed_tokens + + logger.info( + "Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d", + request.request_id, + request.num_tokens, + num_external_hit_tokens, + need_to_allocate, + ) + + if need_to_allocate <= 0: + return 0, False + + self.load_specs[request.request_id] = LoadSpec( + vllm_cached_tokens=num_computed_tokens, + mooncake_cached_tokens=num_external_hit_tokens, + can_load=False, + ) + + return need_to_allocate, not self.use_layerwise + + def update_state_after_alloc(self, request: "Request", + blocks: "KVCacheBlocks", + num_external_tokens: int): + """ + Update KVConnector state after temporary buffer alloc. + + For SharedStorageConnector, update _request_needs_load + if the CacheManager this allocated blocks for us. + """ + local_block_ids = [] + if num_external_tokens > 0: + local_block_ids = blocks.get_block_ids()[0] + + self._unfinished_requests[request.request_id] = (request, + local_block_ids) + self._unfinished_request_ids.add(request.request_id) + if request.request_id not in self.load_specs: + # No KV tokens from external KV cache, return + return + + if num_external_tokens == 0: + # No need to load anything + self.load_specs[request.request_id].can_load = False + return + + assert ( + num_external_tokens > 0 and num_external_tokens + == self.load_specs[request.request_id].mooncake_cached_tokens - + self.load_specs[request.request_id].vllm_cached_tokens + ), (f"Mismatch in number of tokens: {num_external_tokens} vs " + f"{self.load_specs[request.request_id].mooncake_cached_tokens} - " + f"{self.load_specs[request.request_id].vllm_cached_tokens}" + f" for request {request.request_id}") + + self.load_specs[request.request_id].can_load = True + + def build_connector_meta( + self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: + """Attach the connector metadata to the request object. + + This function should NOT modify other fields in the scheduler_output + except the `kv_connector_metadata` field. + Also, calling this function will reset the state of the connector. + + Args: + scheduler_output (SchedulerOutput): the scheduler output object. + """ + + force_skip_save = self.kv_role == "kv_consumer" + + for finished_req_id in scheduler_output.finished_req_ids: + self._request_trackers.pop(finished_req_id, None) + self._unfinished_requests.pop(finished_req_id, None) + self._unfinished_request_ids.remove(finished_req_id) + + meta = MooncakeConnectorMetadata(self._unfinished_request_ids) + + for request in scheduler_output.scheduled_new_reqs: + # Right now, we only load KV for new requests + load_spec = self.load_specs.pop(request.req_id, None) + num_tokens_to_compute = ( + request.num_computed_tokens + + scheduler_output.num_scheduled_tokens[request.req_id]) + request_tracker = RequestTracker.from_new_request( + request, num_tokens_to_compute) + self._request_trackers[request.req_id] = request_tracker + last_chunk_tokens_num = ((len(request.prompt_token_ids) // + self._block_size * self._block_size) + if self._discard_partial_chunks else len( + request.prompt_token_ids)) + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=load_spec, + skip_save=force_skip_save, + is_last_chunk=len(request_tracker.token_ids) + >= last_chunk_tokens_num, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + + cached_reqs = scheduler_output.scheduled_cached_reqs + if isinstance(cached_reqs, list) and not force_skip_save: + for i, req in enumerate(cached_reqs): + request_tracker = self._request_trackers[req.req_id] + request_tracker.update(req.new_token_ids, req.new_block_ids) + last_chunk_tokens_num = ((len(req.prompt_token_ids) // + self._block_size * self._block_size) + if self._discard_partial_chunks else + len(req.prompt_token_ids)) + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=None, + skip_save=force_skip_save, + is_last_chunk=len(request_tracker.token_ids) + >= last_chunk_tokens_num, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + elif not force_skip_save: + for i, req_id in enumerate(cached_reqs.req_ids): + request_tracker = self._request_trackers[req_id] + num_new_tokens = scheduler_output.num_scheduled_tokens[req_id] + req_tuple = self._unfinished_requests.get(req_id) + if req_tuple: + request = req_tuple[0] + num_current_tokens = len(request_tracker.token_ids) + new_token_ids = request.all_token_ids[ + num_current_tokens:num_current_tokens + num_new_tokens] + else: + raise ValueError( + f"Request {req_id} is not in _unfinished_requests, " + f"but it is scheduled to be cached") + new_block_ids = cached_reqs.new_block_ids[i] + if not new_block_ids: + continue + request_tracker.update(new_token_ids, new_block_ids) + # decode not save + if len(request_tracker.token_ids) > len( + request.prompt_token_ids): + continue + + last_chunk_tokens_num = ((len(request.prompt_token_ids) // + self._block_size * self._block_size) + if self._discard_partial_chunks else + len(request.prompt_token_ids)) + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=None, + skip_save=force_skip_save, + is_last_chunk=len(request_tracker.token_ids) + >= last_chunk_tokens_num, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + + request_ids = [ + req.req_id for req in scheduler_output.scheduled_new_reqs + ] + for request_id, (request, + block_ids) in self._unfinished_requests.items(): + if request_id not in request_ids and request_id not in cached_reqs.req_ids: + load_spec = self.load_specs.pop(request_id, None) + if not load_spec: + continue + num_tokens_to_compute = load_spec.mooncake_cached_tokens + if (num_tokens_to_compute % self._block_size + != 0) and (num_tokens_to_compute + == len(request.prompt_token_ids) - 1): + num_tokens_to_compute = num_tokens_to_compute + 1 + request_tracker = RequestTracker( + req_id=request_id, + token_ids=request.prompt_token_ids[:num_tokens_to_compute]. + copy(), + allocated_block_ids=block_ids, + num_saved_tokens=0, + ) + + self._request_trackers[request_id] = request_tracker + + req_meta = ReqMeta.from_request_tracker( + request_tracker, + self._block_size, + load_spec=load_spec, + skip_save=None, + discard_partial_chunks=self._discard_partial_chunks, + ) + if req_meta is not None: + meta.add_request(req_meta) + return meta + + def request_finished( + self, + request: "Request", + block_ids: list[int], + ) -> tuple[bool, Optional[dict[str, Any]]]: + """ + Once a request is finished, determine whether request blocks + should be freed now or will be sent asynchronously and freed later. + """ + if self.kv_role == "kv_consumer": + return False, None + if self._request_trackers[request.request_id].num_saved_tokens <= 0: + return False, None + delay_free_blocks = len(block_ids) > 0 + if delay_free_blocks: + logger.info("Delaying free of %d blocks for request %s", + len(block_ids), request.request_id) + return delay_free_blocks, None + + +class MooncakeLookupClient: + + def __init__(self, vllm_config: "VllmConfig"): + self.encoder = MsgpackEncoder() + self.ctx = zmq.Context() # type: ignore[attr-defined] + socket_path = get_zmq_rpc_path_mooncake(vllm_config) + self.socket = make_zmq_socket( + self.ctx, + socket_path, + zmq.REQ, # type: ignore[attr-defined] + bind=False, + ) + + def lookup(self, token_ids: torch.Tensor) -> int: + request = self.encoder.encode(token_ids) + self.socket.send_multipart(request, copy=False) + resp = self.socket.recv() + result = int.from_bytes(resp, "big") + return result + + def close(self): + self.socket.close(linger=0) + + +class MooncakeLookupServer: + + def __init__( + self, + mooncake_engine: MooncakeEngine, + vllm_config: "VllmConfig", + use_layerwise: bool, + ): + self.decoder = MsgpackDecoder(torch.Tensor) + self.ctx = zmq.Context() # type: ignore[attr-defined] + socket_path = get_zmq_rpc_path_mooncake(vllm_config) + self.socket = make_zmq_socket( + self.ctx, + socket_path, + zmq.REP, # type: ignore[attr-defined] + bind=True, + ) + + self.mooncake_engine = mooncake_engine + self.running = True + + def process_request(): + while self.running: + frames = self.socket.recv_multipart(copy=False) + token_ids = self.decoder.decode(frames) + result = self.mooncake_engine.lookup(token_ids, use_layerwise) + response = result.to_bytes(4, "big") + self.socket.send(response) + + self.thread = threading.Thread(target=process_request, daemon=True) + self.thread.start() + + def close(self): + self.socket.close(linger=0) + # TODO: close the thread! \ No newline at end of file diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index f867e5a..7e553ae 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1811,13 +1811,9 @@ class NPUModelRunner(LoRAModelRunnerMixin): if self.drafter and self.drafter.name == SpecDcodeType.EAGLE3: hidden_states, aux_hidden_states = hidden_states - kv_connector_output = None - if finished_sending is not None or finished_recving is not None: - kv_connector_output = KVConnectorOutput( - finished_sending=finished_sending, - finished_recving=finished_recving) - else: - kv_connector_output = None + kv_connector_output = KVConnectorOutput( + finished_sending=finished_sending, + finished_recving=finished_recving) finished_sending = None finished_recving = None with ProfileExecuteDuration().capture_async("post process"): @@ -2067,8 +2063,6 @@ class NPUModelRunner(LoRAModelRunnerMixin): # For the case of no forward caused by receiving remote kv, # one round of dummy inference is necessary # to prevent hang over the collective calls. - if not finished_sending and not finished_recving: - return EMPTY_MODEL_RUNNER_OUTPUT output = copy.copy(EMPTY_MODEL_RUNNER_OUTPUT) output.kv_connector_output = KVConnectorOutput(