[Feature][main] Refactor kvpool connector to ascend connector (#4438)

### What this PR does / why we need it?
1. In short, we renamed the existing MooncakeStoreConnector to
AscendStoreConnector and extracted the storage-engine interaction logic
into a new Backend class.
Associated RFC: https://github.com/vllm-project/vllm-ascend/issues/4329
2. Fixed the connector being constructed with the wrong number of input
parameters, an issue introduced in vLLM 0.11.2.
### Does this PR introduce _any_ user-facing change?
Yes: MooncakeStoreConnector is renamed to AscendStoreConnector.
### How was this patch tested?

- vLLM version: v0.11.2

---------

Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
fems14
2025-11-28 18:08:37 +08:00
committed by GitHub
parent 554f16ae1f
commit 5447a039b9
25 changed files with 1489 additions and 1511 deletions

View File

@@ -14,6 +14,6 @@ lora
eplb_swift_balancer
netloader
dynamic_batch
kv_pool_mooncake
kv_pool
external_dp
:::

View File

@@ -1,4 +1,4 @@
# Mooncacke Store Deployment Guide
# Ascend Store Deployment Guide
## Environmental Dependencies
@@ -8,27 +8,30 @@
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
* vLLM: main branch
* vLLM-Ascend: main branch
* Mooncake: main branch
Installation and Compilation Guide: https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#build-and-use-binaries
Make sure to build with `-DUSE_ASCEND_DIRECT` to enable the ADXL engine.
An example command for compiling ADXL:
`rm -rf build && mkdir -p build && cd build && cmake .. -DCMAKE_INSTALL_PREFIX=/opt/transfer-engine/ -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DUSE_ASCEND_DIRECT=ON -DBUILD_SHARED_LIBS=ON -DBUILD_UNIT_TESTS=OFF && make -j && make install`
Also, set environment variables so the Mooncake libraries can be found (`export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64/python3.11/site-packages/mooncake`), or copy the .so files to `/usr/local/lib64` after compilation.
### KV Pooling Parameter Description
**kv_connector_extra_config**: additional configurable parameters for pooling.
**mooncake_rpc_port**: port for RPC communication between the pooling scheduler process and the worker processes; each instance requires a unique port.
**lookup_rpc_port**: port for RPC communication between the pooling scheduler process and the worker processes; each instance requires a unique port.
**load_async**: whether to enable asynchronous loading. The default value is false.
**register_buffer**: whether to register device memory with the backend. Registration is not required when used with MooncakeConnectorV1; it is required in all other cases. The default value is false.
**backend**: the storage backend for kvpool. The default is mooncake.
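For reference, a minimal `--kv-transfer-config` using these parameters might look like the following sketch (values are illustrative; see the full launch commands below):

```
'{
  "kv_connector": "AscendStoreConnector",
  "kv_role": "kv_both",
  "kv_connector_extra_config": {
    "lookup_rpc_port": "0",
    "load_async": false,
    "register_buffer": false,
    "backend": "mooncake"
  }
}'
```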
## Run Mooncake Master
## Example of using Mooncake as a KVCache pooling backend
* Software:
* Mooncake: main branch
### 1. Configure mooncake.json
Installation and Compilation Guide: https://github.com/kvcache-ai/Mooncake?tab=readme-ov-file#build-and-use-binaries
Make sure to build with `-DUSE_ASCEND_DIRECT` to enable the ADXL engine.
An example command for compiling ADXL:
`rm -rf build && mkdir -p build && cd build && cmake .. -DCMAKE_INSTALL_PREFIX=/opt/transfer-engine/ -DCMAKE_POLICY_VERSION_MINIMUM=3.5 -DUSE_ASCEND_DIRECT=ON -DBUILD_SHARED_LIBS=ON -DBUILD_UNIT_TESTS=OFF && make -j && make install`
Also, set environment variables so the Mooncake libraries can be found (`export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/lib64/python3.11/site-packages/mooncake`), or copy the .so files to `/usr/local/lib64` after compilation.
### Run Mooncake Master
#### 1. Configure mooncake.json
Set the environment variable **MOONCAKE_CONFIG_PATH** to the full path of mooncake.json.
@@ -54,7 +57,7 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path
**master_server_address**: the IP and port of the master service.
**global_segment_size**: the size of the KV cache segment that each PD node registers with the master.
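A sketch of a minimal mooncake.json (field names follow the MooncakeStoreConfig loader added in this PR; values are illustrative placeholders):

```
{
  "local_hostname": "<local-ip>",
  "metadata_server": "<metadata-server-address>",
  "protocol": "ascend",
  "device_name": "",
  "master_server_address": "<master-ip>:50088",
  "global_segment_size": "3.125 GB",
  "local_buffer_size": 1073741824
}
```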
### 2. Start mooncake_master
#### 2. Start mooncake_master
Under the mooncake folder:
@@ -64,9 +67,9 @@ mooncake_master --port 50088 --eviction_high_watermark_ratio 0.95 --eviction_rat
`eviction_high_watermark_ratio` determines the watermark at which Mooncake Store performs eviction, and `eviction_ratio` determines the portion of stored objects that will be evicted.
## Pooling and Prefill Decode Disaggregate Scenario
### Pooling and Prefill Decode Disaggregate Scenario
### 1. Run `prefill` Node and `decode` Node
#### 1. Run `prefill` Node and `decode` Node
MultiConnector is used to combine the p2p connector with the pooling connector: p2p performs kv_transfer, while pooling provides a larger prefix cache.
@@ -123,9 +126,10 @@ python3 -m vllm.entrypoints.openai.api_server \
}
},
{
"kv_connector": "MooncakeConnectorStoreV1",
"kv_connector": "AscendStoreConnector",
"kv_role": "kv_producer",
"mooncake_rpc_port":"0"
"lookup_rpc_port":"0",
"backend": "mooncake"
}
]
}
@@ -185,16 +189,17 @@ python3 -m vllm.entrypoints.openai.api_server \
}
},
{
"kv_connector": "MooncakeConnectorStoreV1",
"kv_connector": "AscendStoreConnector",
"kv_role": "kv_consumer",
"mooncake_rpc_port":"1"
"lookup_rpc_port":"1",
"backend": "mooncake"
}
]
}
}' > d.log 2>&1
```
### 2. Start proxy_server
#### 2. Start proxy_server
```
bash proxy.sh
@@ -212,7 +217,7 @@ python vllm-ascend/examples/disaggregated_prefill_v1/load_balance_proxy_server_e
--decoder-ports 8200 \
```
### 3. Run Inference
#### 3. Run Inference
Configure the localhost, port, and model weight path in the command to your own settings.
@@ -228,9 +233,9 @@ Long question:
curl -s http://localhost:8000/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Given the accelerating impacts of climate change—including rising sea levels, increasing frequency of extreme weather events, loss of biodiversity, and adverse effects on agriculture and human health—there is an urgent need for a robust, globally coordinated response. However, international efforts are complicated by a range of factors: economic disparities between high-income and low-income countries, differing levels of industrialization, varying access to clean energy technologies, and divergent political systems that influence climate policy implementation. In this context, how can global agreements like the Paris Accord be redesigned or strengthened to not only encourage but effectively enforce emission reduction targets? Furthermore, what mechanisms can be introduced to promote fair and transparent technology transfer, provide adequate financial support for climate adaptation in vulnerable regions, and hold nations accountable without exacerbating existing geopolitical tensions or disproportionately burdening those with historically lower emissions?", "max_tokens": 256, "temperature":0.0 }'
```
## Pooling and Mixed Deployment Scenario
### Pooling and Mixed Deployment Scenario
### 1. Run Mixed Deployment Script
#### 1. Run Mixed Deployment Script
The mixed-deployment script is essentially a pure pooling scenario for the P node.
@@ -263,19 +268,17 @@ python3 -m vllm.entrypoints.openai.api_server \
--max-num-batched-tokens 4096 \
--kv-transfer-config \
'{
"kv_connector": "MooncakeConnectorStoreV1",
"kv_connector": "AscendStoreConnector",
"kv_role": "kv_both",
"kv_connector_extra_config": {
"register_buffer": true,
"use_layerwise": false,
"mooncake_rpc_port":"0"
"lookup_rpc_port":"1",
"backend": "mooncake"
}
}' > mix.log 2>&1
```
`register_buffer` is set to `false` by default and needs to be set to `true` only in the PD-mixed scenario.
### 2. Run Inference
#### 2. Run Inference
Configure the localhost, port, and model weight path in the command to your own settings. Requests are sent directly to the port served by the mixed-deployment script; there is no need to start a separate proxy.
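For example (a sketch; replace `<port>` with the port your mixed-deployment script serves on):

```
curl -s http://localhost:<port>/v1/completions -H "Content-Type: application/json" -d '{ "model": "/xxxxx/Qwen2.5-7B-Instruct", "prompt": "Hello, world", "max_tokens": 64, "temperature": 0.0 }'
```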

View File

@@ -1,6 +1,13 @@
import sys
import types
import unittest
from unittest.mock import MagicMock
from vllm_ascend.distributed.mooncake.config_data import (
fake_engine = types.ModuleType("mooncake.engine")
fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined]
sys.modules["mooncake.engine"] = fake_engine
from vllm_ascend.distributed.kvpool.backend.mooncake_backend import ( # noqa: E402
_convert_to_bytes, _parse_global_segment_size)

View File

@@ -1051,7 +1051,7 @@ class TestMooncakeConnectorWorker(unittest.TestCase):
'vllm_ascend.distributed.mooncake_connector.string_to_int64_hash',
mock_string_to_int64_hash),
patch(
'vllm_ascend.distributed.mooncake.transfer_engine.TransferEngine',
'vllm_ascend.distributed.mooncake_transfer_engine.TransferEngine',
return_value=self.mock_transfer_engine),
patch(
'vllm_ascend.distributed.mooncake_connector.KVCacheSendingThread',

View File

@@ -31,8 +31,13 @@ def register_connector():
KVConnectorFactory.register_connector(
"MooncakeConnectorStoreV1",
"vllm_ascend.distributed.mooncake.mooncake_store_connector_v1",
"MooncakeConnectorV1")
"vllm_ascend.distributed.kvpool.ascend_store_connector",
"AscendStoreConnector")
KVConnectorFactory.register_connector(
"AscendStoreConnector",
"vllm_ascend.distributed.kvpool.ascend_store_connector",
"AscendStoreConnector")
KVConnectorFactory.register_connector(
"MooncakeLayerwiseConnector",

View File

@@ -29,6 +29,7 @@ if TYPE_CHECKING:
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.forward_context import ForwardContext
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
@@ -58,7 +59,10 @@ class CPUOffloadingConnectorMetadata(KVConnectorMetadata):
class CPUOffloadingConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: "VllmConfig", role: KVConnectorRole):
def __init__(self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None):
if not vllm_config.cache_config.enable_prefix_caching:
self.connector_scheduler: Optional[
CPUOffloadingConnectorScheduler] = None

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,194 @@
import threading
from typing import Any, Optional
import torch
import zmq
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.forward_context import ForwardContext
from vllm.utils import logger
from vllm.utils.network_utils import make_zmq_socket
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackDecoder
from vllm_ascend.distributed.kvpool.pool_scheduler import (
KVPoolScheduler, get_zmq_rpc_path_lookup)
from vllm_ascend.distributed.kvpool.pool_worker import KVPoolWorker
class AscendStoreConnector(KVConnectorBase_V1):
def __init__(self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None):
super().__init__(vllm_config=vllm_config,
role=role,
kv_cache_config=kv_cache_config)
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"use_layerwise", False)
connector_name = vllm_config.kv_transfer_config.kv_connector
if connector_name == "MooncakeConnectorStoreV1":
logger.warning(
"It is recommended to use the AscendStoreConnector, as the MoonCakeStoreConnector will be removed in the future."
)
self.kv_caches: dict[str, torch.Tensor] = {}
self._block_size = vllm_config.cache_config.block_size
self.sended_but_unfinished_reqs: set[str] = set()
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = KVPoolScheduler(vllm_config,
self.use_layerwise)
else:
self.connector_worker = KVPoolWorker(
vllm_config,
self.use_layerwise,
)
assert self.connector_worker is not None
if vllm_config.parallel_config.rank == 0:
self.lookup_server = LookupKeyServer(self.connector_worker,
vllm_config,
self.use_layerwise)
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
self.connector_worker.start_load_kv(self._get_connector_metadata())
def wait_for_layer_load(self, layer_name: str) -> None:
if not self.use_layerwise:
return
self.connector_worker.wait_for_layer_load()
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
if not self.use_layerwise:
return
if self.kv_role == "kv_consumer":
# Don't do save if the role is kv_consumer
return
self.connector_worker.save_kv_layer(self._get_connector_metadata())
def wait_for_save(self):
if self.kv_role == "kv_consumer":
# Don't do save if the role is kv_consumer
return
if self.use_layerwise:
return
self.connector_worker.wait_for_save(self._get_connector_metadata())
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
meta = self._get_connector_metadata()
done_sending, done_recving = self.connector_worker.get_finished()
sended_and_finished: set[str] = set()
for item in list(self.sended_but_unfinished_reqs):
if item not in meta.unfinished_request_ids:
sended_and_finished.add(item)
self.sended_but_unfinished_reqs.remove(item)
for item in done_sending:
if item in meta.unfinished_request_ids:
self.sended_but_unfinished_reqs.add(item)
else:
sended_and_finished.add(item)
return sended_and_finished, done_recving
class LookupKeyServer:
def __init__(
self,
pool_worker: KVPoolWorker,
vllm_config: "VllmConfig",
use_layerwise: bool,
):
self.decoder = MsgpackDecoder()
self.decoder_tensor = MsgpackDecoder(torch.Tensor)
self.ctx = zmq.Context() # type: ignore[attr-defined]
socket_path = get_zmq_rpc_path_lookup(vllm_config)
self.socket = make_zmq_socket(
self.ctx,
socket_path,
zmq.REP, # type: ignore[attr-defined]
bind=True,
)
self.pool_worker = pool_worker
self.running = True
self.use_layerwise = use_layerwise
def process_request():
while self.running:
all_frames = self.socket.recv_multipart(copy=False)
token_len = int.from_bytes(all_frames[0], byteorder="big")
hash_frames = all_frames[1:]
hashes_str = self.decoder.decode(hash_frames)
result = self.pool_worker.lookup_scheduler(
token_len, hashes_str, self.use_layerwise)
response = result.to_bytes(4, "big")
self.socket.send(response)
self.thread = threading.Thread(target=process_request, daemon=True)
self.thread.start()
def close(self):
self.socket.close(linger=0)
# TODO: close the thread!
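# Illustrative (not part of this file): the framing this server expects,
# mirroring LookupKeyClient in pool_scheduler.py. A client sends a 4-byte
# big-endian token length frame followed by msgpack-encoded hex block hashes,
# and receives the number of matched tokens as a 4-byte big-endian integer:
#
#     token_len_frame = (256).to_bytes(4, byteorder="big")
#     hash_frames = MsgpackEncoder().encode(["<hash0>", "<hash1>"])
#     socket.send_multipart([token_len_frame] + list(hash_frames), copy=False)
#     hit_tokens = int.from_bytes(socket.recv(), "big")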

View File

@@ -0,0 +1 @@

View File

@@ -0,0 +1,29 @@
from abc import ABC, abstractmethod
from vllm.config import ParallelConfig
class Backend(ABC):
def __init__(self, parallel_config: ParallelConfig):
pass
def set_device(self):
pass
def register_buffer(self, ptrs: list[int], lengths: list[int]):
pass
@abstractmethod
def exists(self, keys: list[str]) -> list[int]:
pass
@abstractmethod
def put(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
pass
@abstractmethod
def get(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
pass
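# Illustrative (not part of this file): concrete backends are selected by the
# "backend" entry of kv_connector_extra_config; see backend_map in
# pool_worker.py:
#
#     backend_cls = backend_map[extra_config.get("backend", "mooncake").lower()]
#     store: Backend = backend_cls(parallel_config)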

View File

@@ -0,0 +1,74 @@
# Standard
from enum import Enum
import torch
from vllm.config import ParallelConfig
from vllm.utils import logger
from vllm_ascend.distributed.kvpool.backend.backend import Backend
class MmcDirect(Enum):
COPY_L2G = 0
COPY_G2L = 1
COPY_G2H = 2
COPY_H2G = 3
class MemcacheBackend(Backend):
def __init__(self, parallel_config: ParallelConfig):
try:
from memcache import DistributedObjectStore # type: ignore
except ImportError as e:
raise ImportError(
"Please install memcache by following the instructions at "
"https://gitee.com/ascend/memfabric_hybrid " # noqa: E501
"to run vLLM with MemcacheConnector.") from e
try:
self.rank = parallel_config.rank
self.store = DistributedObjectStore()
res = self.store.init(self.rank)
assert res == 0
except ValueError as e:
logger.error("Configuration loading failed: %s", e)
raise
except Exception as exc:
logger.error(
"An error occurred while loading the configuration: %s", exc)
raise
def set_device(self):
device = torch.device(f"npu:{self.rank}")
torch.npu.set_device(device)
def register_buffer(self, ptrs: list[int], sizes: list[int]):
for ptr, size in zip(ptrs, sizes):
ret_value = self.store.register_buffer(ptr, size)
if ret_value != 0:
raise RuntimeError("Memcache memory registration failed.")
def exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys)
def get(self, key: list[str], addr: list[list[int]],
size: list[list[int]]):
try:
res = self.store.batch_get_into_layers(key, addr, size,
MmcDirect.COPY_G2L.value)
for value in res:
if value != 0:
logger.error(f"Failed to get key {key},res:{res}")
except Exception as e:
logger.error(f"Failed to get key {key}. {e}")
def put(self, key: list[str], addr: list[list[int]],
size: list[list[int]]):
try:
res = self.store.batch_put_from_layers(key, addr, size,
MmcDirect.COPY_L2G.value)
            for value in res:
                if value != 0:
                    logger.error(f"Failed to put key {key}, res: {res}")
        except Exception as e:
            logger.error(f"Failed to put key {key}, error: {e}")

View File

@@ -0,0 +1,188 @@
# Standard
import json
import os
import re
from dataclasses import dataclass
from typing import Union
# Third Party
from vllm.config import ParallelConfig
from vllm.utils import logger
from vllm.utils.network_utils import get_ip
from vllm_ascend.distributed.kvpool.backend.backend import Backend
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
class MooncakeBackend(Backend):
def __init__(self, parallel_config: ParallelConfig):
try:
from mooncake.store import MooncakeDistributedStore # type: ignore
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector.") from e
self.config = MooncakeStoreConfig.load_from_env()
self.store = MooncakeDistributedStore()
if self.config.protocol == "ascend":
local_hostname = get_ip()
transfer_engine = global_te.get_transfer_engine(local_hostname,
device_name=None)
self.local_seg = local_hostname + ":" + str(
transfer_engine.get_rpc_port())
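        # NOTE: the setup call below assumes protocol == "ascend"; local_seg
        # and transfer_engine are only initialized on that branch.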
ret = self.store.setup(self.local_seg, self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
transfer_engine.get_engine())
if ret != 0:
msg = "Initialize mooncake failed."
logger.error(msg)
raise RuntimeError(msg)
def register_buffer(self, ptrs: list[int], lengths: list[int]):
global_te.register_buffer(ptrs, lengths)
def exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys)
def put(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
try:
res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes)
for value in res:
                if value < 0:
                    logger.error(f"Failed to put key {keys}, res: {res}")
        except Exception as e:
            logger.error(f"Failed to put key {keys}, error: {e}")
def get(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]]):
try:
res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes)
for value in res:
if value < 0:
logger.error(f"Failed to get key {keys}, res:{res}")
except Exception as e:
logger.error(f"Failed to get key {keys}, error:{e}")
@dataclass
class MooncakeStoreConfig:
local_hostname: str
metadata_server: str
global_segment_size: Union[int, str]
local_buffer_size: int
protocol: str
device_name: str
master_server_address: str
use_ascend_direct: bool
@staticmethod
def from_file(file_path: str) -> "MooncakeStoreConfig":
with open(file_path) as file:
config = json.load(file)
return MooncakeStoreConfig(
local_hostname=config.get("local_hostname"),
metadata_server=config.get("metadata_server"),
global_segment_size=_parse_global_segment_size(
config.get("global_segment_size",
DEFAULT_GLOBAL_SEGMENT_SIZE)),
local_buffer_size=(config.get("local_buffer_size",
DEFAULT_LOCAL_BUFFER_SIZE)),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"),
use_ascend_direct=config.get("use_ascend_direct", False))
@staticmethod
def load_from_env() -> "MooncakeStoreConfig":
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
if not config_path:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
return MooncakeStoreConfig.from_file(config_path)
def _parse_global_segment_size(value) -> int:
"""
Parse storage size strings with support for units: GB, MB, KB, B
Args:
value: Input value (int, str, or other convertible types)
Returns:
int: Size in bytes
Raises:
ValueError: For invalid format, missing number, or negative values
TypeError: For unsupported input types
"""
if isinstance(value, int):
return value
elif not isinstance(value, str):
try:
return int(value)
except (TypeError, ValueError) as e:
raise TypeError(
f"Unsupported type for global_segment_size: {type(value)}"
) from e
cleaned_input = value.strip().lower()
if not cleaned_input:
raise ValueError("global segment size cannot be empty.")
UNIT_MULTIPLIERS = {
'gb': 1024**3, # 1 GB = 1024^3 bytes
'mb': 1024**2, # 1 MB = 1024^2 bytes
'kb': 1024, # 1 KB = 1024 bytes
'b': 1 # 1 B = 1 byte
}
pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
match = re.match(pattern, cleaned_input)
if not match:
raise ValueError(f"Invalid format: '{value}'")
number_str = match.group(1)
unit = match.group(2) or 'b'
multiplier = UNIT_MULTIPLIERS[unit]
return _convert_to_bytes(number_str, multiplier, value)
def _convert_to_bytes(number_str: str, multiplier: int,
original_input: str) -> int:
"""
Convert numeric string to byte count
Args:
number_str: Numeric portion of input
multiplier: Unit conversion factor
original_input: Original input string (for error messages)
Returns:
int: Byte count
Raises:
ValueError: For invalid numbers or negative results
"""
try:
numeric_value = float(number_str)
except ValueError:
raise ValueError(
f"Invalid numeric value '{number_str}' in: '{original_input}'")
# Calculate byte count
try:
byte_count = int(numeric_value * multiplier)
except OverflowError:
raise ValueError(f"Storage size too large: '{original_input}'")
return byte_count
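# Illustrative usage (not part of this file):
#
#     _parse_global_segment_size("3.125 GB")  # -> 3355443200
#     _parse_global_segment_size("512mb")     # -> 536870912
#     _parse_global_segment_size(1073741824)  # -> 1073741824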

View File

@@ -0,0 +1,364 @@
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.utils import logger
from vllm.utils.math_utils import cdiv
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import NewRequestData
# Parameters related to the key
@dataclass
class KeyMetadata:
"""name of the LLM model"""
model_name: str
""" worker id when running under a distributed setting """
head_or_tp_rank: int
@dataclass(order=True)
class PoolKey:
key_metadata: KeyMetadata
chunk_hash: str
def __hash__(self):
return hash((
self.key_metadata.model_name,
self.key_metadata.head_or_tp_rank,
self.chunk_hash,
))
def to_string(self):
return (
f"{self.key_metadata.model_name}"
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}"
)
def split_layers(self, num_layers: int) -> List["LayerPoolKey"]:
"""Split the key into multiple keys for each layer"""
keys = []
for layer_id in range(num_layers):
keys.append(
LayerPoolKey(
self.key_metadata,
self.chunk_hash,
layer_id,
))
return keys
@dataclass(order=True)
class LayerPoolKey(PoolKey):
"""A key for the layer cache engine"""
layer_id: int
def __hash__(self):
return hash((
self.key_metadata.model_name,
self.key_metadata.head_or_tp_rank,
self.chunk_hash,
self.layer_id,
))
def to_string(self):
return (
f"{self.key_metadata.model_name}"
f"@head_or_tp_rank:{self.key_metadata.head_or_tp_rank}@{self.chunk_hash}@{self.layer_id}"
)
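# Illustrative (not part of this file): how pool keys serialize.
#
#     meta = KeyMetadata(model_name="Qwen2.5-7B-Instruct", head_or_tp_rank=0)
#     PoolKey(meta, "ab12cd").to_string()
#     # -> "Qwen2.5-7B-Instruct@head_or_tp_rank:0@ab12cd"
#     PoolKey(meta, "ab12cd").split_layers(2)[1].to_string()
#     # -> "Qwen2.5-7B-Instruct@head_or_tp_rank:0@ab12cd@1"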
class ChunkedTokenDatabase:
def __init__(self, metadata: KeyMetadata, block_size: int, use_mla: bool):
self.metadata = metadata
self.block_size = block_size
self.use_mla = use_mla
self.kv_caches_base_addr: list[int] = []
self.block_len: list[int] = []
def _make_key_by_hash(self,
chunk_hash: str,
layer_id: Optional[int] = None):
assert self.metadata is not None
return PoolKey(
self.metadata,
chunk_hash,
)
def set_kv_caches_base_addr(self, kv_caches_base_addr: list[int]):
self.kv_caches_base_addr = kv_caches_base_addr
def set_block_len(self, block_len: list[int]):
self.block_len = block_len
def prepare_value(self, start: int, end: int, block_ids: list[int]):
addr_list = []
size_list = []
block_id = block_ids[start // self.block_size]
for index, base_addr in enumerate(self.kv_caches_base_addr):
block_len = (self.block_len[index % 2]
if self.use_mla else self.block_len[0])
addr = base_addr + block_id * block_len
length = int(block_len / self.block_size * (end - start))
addr_list.append(addr)
size_list.append(length)
return addr_list, size_list, block_id
def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
layer_id: int):
block_id = block_ids[start // self.block_size]
if self.use_mla:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[1]
length_k = int(self.block_len[0] / self.block_size * (end - start))
length_v = int(self.block_len[1] / self.block_size * (end - start))
size_list = [length_k, length_v]
else:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[0]
length = int(self.block_len[0] / self.block_size * (end - start))
size_list = [length, length]
addr_list = [addr_k, addr_v]
return addr_list, size_list
def process_tokens(
self,
token_len: int,
block_hashes: Union[list[BlockHash], list[str]],
mask_num: int = 0,
) -> Iterable[Tuple[int, int, PoolKey]]:
"""Process the tokens and return the corresponding cache engine keys.
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. And the mask should ALWAYS be like
FFFFFTTTTTTT, where True means the tokens needs to be matched,
and the Falses will ALWAYS be at the PREFIX of the tensor.
:param bool make_key: Whether to make the cache engine key or not.
If False, the hash value will be returned instead.
:returns: A iterable of tuples with three elements. The first element
is the start index of the tokens for the key. The second element
is the end index of the tokens for the key. The third element is
the cache engine key (or hash) for the tokens.
:raises: ValueError if the number of Falses in the mask is not a
multiple of the chunk size.
"""
if not block_hashes:
return
if not isinstance(block_hashes[0], str):
block_hashes = [
h.hex() # type: ignore[union-attr]
for h in block_hashes
]
start_idx = 0
for chunk_id, hash_val in enumerate(block_hashes):
start_idx = chunk_id * self.block_size
if start_idx >= token_len:
break
end_idx = min(start_idx + self.block_size, token_len)
if start_idx < mask_num:
continue
else:
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
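# Illustrative (not part of this file): iterating chunk keys for a request of
# 10 tokens with block_size=4 and three block hashes:
#
#     db = ChunkedTokenDatabase(KeyMetadata("m", 0), block_size=4,
#                               use_mla=False)
#     [(s, e, k.to_string()) for s, e, k in
#      db.process_tokens(10, ["h0", "h1", "h2"])]
#     # -> [(0, 4, "m@head_or_tp_rank:0@h0"),
#     #     (4, 8, "m@head_or_tp_rank:0@h1"),
#     #     (8, 10, "m@head_or_tp_rank:0@h2")]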
# Parameters related to the connector metadata
@dataclass
class LoadSpec:
# Number of tokens cached in vLLM
vllm_cached_tokens: int
# Number of tokens that are cached in kvpool
kvpool_cached_tokens: int
# Whether the scheduler allow us to load the tokens
can_load: bool
@dataclass
class RequestTracker:
# Request id
req_id: str
    # The number of tokens that have been scheduled so far
    token_len: int
    # The block ids that have been allocated so far
    # NOTE: allocated blocks could be more than the number of tokens
    # FIXME: need to check whether the block ids will be changed after
    # preemption
    allocated_block_ids: list[int]
    # The number of tokens that have been saved
num_saved_tokens: int = 0
@staticmethod
def from_new_request(
new_request: "NewRequestData",
num_tokens_to_compute: int,
) -> "RequestTracker":
"""Create the request tracker from a new request.
Args:
new_request (NewRequestData): the new request data.
num_tokens_to_compute (int): the number of tokens that will
be 'computed', including the `num_computed_tokens` (vLLM's
local cache hit) and new tokens that will be scheduled.
"""
unfolded_block_ids = []
if not isinstance(new_request.block_ids[0], list):
unfolded_block_ids = new_request.block_ids.copy()
else:
unfolded_block_ids = new_request.block_ids[0].copy()
return RequestTracker(
req_id=new_request.req_id,
token_len=num_tokens_to_compute,
allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0,
)
def update(
self,
new_token_ids: list[int],
new_block_ids: Union[tuple[list[int], ...], list[int]],
) -> None:
"""Update the request tracker when a running request is
scheduled again
"""
self.token_len = self.token_len + len(new_token_ids)
if len(new_block_ids) == 0:
new_block_ids = []
elif isinstance(new_block_ids, tuple):
new_block_ids = new_block_ids[0]
elif isinstance(new_block_ids, list):
pass
else:
raise ValueError(
f"Unsupported new_block_ids type {type(new_block_ids)}")
self.allocated_block_ids.extend(new_block_ids)
@dataclass
class ReqMeta:
# Request id
req_id: str
    # Number of request tokens in this chunk
token_len_chunk: int
block_ids: list[int]
block_hashes: list[BlockHash]
can_save: Optional[bool] = None
# load_spec
load_spec: Optional[LoadSpec] = None
is_last_chunk: Optional[bool] = None
@staticmethod
def from_request_tracker(
tracker: RequestTracker,
block_size: int,
load_spec: Optional[LoadSpec] = None,
skip_save: Optional[bool] = False,
block_hashes: list[BlockHash] = [],
is_last_chunk: Optional[bool] = None,
discard_partial_chunks: bool = True,
) -> Optional["ReqMeta"]:
"""Create the request metadata from a request tracker.
Args:
tracker (RequestTracker): the request tracker.
block_size (int): the block size in vLLM.
load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
skip_save (bool): whether to skip the save operation.
discard_partial_chunks (bool): whether to discard partial chunks.
Returns:
the request metadata if we need to perform load/save
operations, None otherwise.
"""
input_token_len = tracker.token_len
# For save operation: do not save if the following condition is met
# 1. has already been saved before (num_saved_tokens > 0)
# 2. number of unsaved tokens is not reached the chunk boundary
chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
block_size if discard_partial_chunks else 0)
# Calculate number of tokens to save based on discard_partial_chunks
# setting
num_tokens_to_save = ((input_token_len // block_size * block_size)
if discard_partial_chunks else input_token_len)
skip_save = skip_save or num_tokens_to_save < chunk_boundary
if skip_save and load_spec is None:
return None
# If we need to save, update the number of saved tokens
if not skip_save:
tracker.num_saved_tokens = num_tokens_to_save
        # For load operation: check whether the request is scheduled to load
if load_spec is not None and load_spec.can_load:
logger.debug(
"Scheduled to load %d tokens for request %s",
load_spec.kvpool_cached_tokens,
tracker.req_id,
)
else:
# Do not load if not in `can_load` state
load_spec = None
logger.debug(
f"request:{tracker.req_id}, meta save spec:{not skip_save}, meta load spec:{load_spec}"
)
return ReqMeta(
req_id=tracker.req_id,
token_len_chunk=num_tokens_to_save,
block_ids=tracker.allocated_block_ids,
can_save=not skip_save,
load_spec=load_spec,
block_hashes=block_hashes,
is_last_chunk=is_last_chunk,
)
class AscendConnectorMetadata(KVConnectorMetadata):
def __init__(self, unfinished_request_ids):
self.requests = []
self.unfinished_request_ids = unfinished_request_ids
def add_request(self, req_meta: ReqMeta) -> None:
"""Add a request to the metadata.
Args:
req_meta (ReqMeta): the request metadata.
"""
self.requests.append(req_meta)
@dataclass
class LasyerMultiBlockReqMeta:
req_id: str
keys: List[LayerPoolKey]
starts: List[int]
ends: list[int]
block_ids: list[int]
layer_id: int
is_last_chunk: bool = True

View File

@@ -0,0 +1,246 @@
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional
import torch
from vllm.utils import logger
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm_ascend.distributed.kvpool.backend.backend import Backend
# isort: off
from vllm_ascend.distributed.kvpool.config_data import (ChunkedTokenDatabase,
LasyerMultiBlockReqMeta
)
# isort: on
class KVTransferThread(threading.Thread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, ready_event: threading.Event, name: str):
super().__init__(daemon=True, name=name)
self.m_store = m_store
self.ready_event = ready_event
self.tp_rank = tp_rank
self.token_database = token_database
self.done_task_lock = threading.Lock()
self.request_queue: queue.Queue[Any] = queue.Queue()
# TODO(jianzs): make this configurable
self.executor = ThreadPoolExecutor(max_workers=32)
self.finished_requests: set[str] = set()
def add_request(
self,
req_id: str,
token_len: int,
block_ids: list[int],
block_hashes: list[BlockHash],
mask_num: int = 0,
is_last_chunk: Optional[bool] = None,
    ) -> None:
        req = {
            "req_id": req_id,
            "token_len": token_len,
            "block_ids": block_ids,
            "block_hashes": block_hashes,
            "mask_num": mask_num,
            "is_last_chunk": is_last_chunk,
        }
self.request_queue.put(req)
def get_and_clear_finished_requests(self) -> set[str]:
"""
Get and clear the requests that have been completed.
Returns:
A set of request IDs that have been completed.
"""
with self.done_task_lock:
finished_requests = self.finished_requests.copy()
self.finished_requests.clear()
return finished_requests
def set_finished_request(self, req_id):
with self.done_task_lock:
self.finished_requests.add(req_id)
def run(self):
"""Run the thread to handle KV cache transfer requests."""
self.m_store.set_device()
self.ready_event.set()
while True:
try:
request_data = self.request_queue.get()
if request_data is None:
logger.warning("Received a None request!")
self.request_queue.task_done()
continue
self._handle_request(request_data)
except Exception as e:
logger.error(f"Error in KVCacheTransferThread: {e}")
def _handle_request(self, req_meta: dict[str, Any]):
pass
class KVCacheStoreSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, put_step: int, ready_event: threading.Event):
super().__init__(m_store,
token_database,
tp_rank,
ready_event,
name="KVCacheSendingThread")
self.put_step = put_step
def _handle_request(self, req_meta: dict[str, Any]):
token_len = req_meta["token_len"]
mask_num = req_meta["mask_num"]
block_ids = req_meta["block_ids"]
block_hashes = req_meta["block_hashes"]
req_id = req_meta["req_id"]
is_last_chunk = req_meta["is_last_chunk"]
addr_list = []
size_list = []
key_list = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
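        # Stride the chunk keys across TP ranks: when num_kv_head < tp_size,
        # put_step consecutive ranks hold replicas of the same KV head, so
        # each rank stores only every put_step-th chunk instead of all of them.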
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
if key_list_tp:
torch.npu.current_stream().synchronize()
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
if is_last_chunk:
self.set_finished_request(req_id)
self.request_queue.task_done()
class KVCacheStoreRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, ready_event: threading.Event):
super().__init__(m_store,
token_database,
tp_rank,
ready_event,
name="KVCacheStoreRecvingThread")
def _handle_request(self, req_meta: dict[str, Any]):
token_len = req_meta["token_len"]
mask_num = req_meta["mask_num"]
block_ids = req_meta["block_ids"]
req_id = req_meta["req_id"]
block_hashes = req_meta["block_hashes"]
addr_list = []
size_list = []
key_list = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
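        # Rotate the lists by tp_rank so each rank starts its batched gets at
        # a different chunk.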
key_list_c = key_list[self.tp_rank %
len(key_list):] + key_list[:self.tp_rank %
len(key_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list):] + size_list[:self.tp_rank %
len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c)
self.set_finished_request(req_id)
self.request_queue.task_done()
class KVCacheStoreLayerSendingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, put_step: int, ready_event: threading.Event,
num_layers: int):
super().__init__(m_store,
token_database,
tp_rank,
ready_event,
name="KVCacheStoreLayerSendingThread")
self.final_layer_id = num_layers - 1
self.put_step = put_step
def add_request( # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta) -> None:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
addr_list = []
size_list = []
key_list = []
for index, key in enumerate(req_meta.keys):
addr, size = self.token_database.prepare_value_layer(
req_meta.starts[index], req_meta.ends[index],
req_meta.block_ids, req_meta.layer_id)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_tp = key_list[self.tp_rank % self.put_step::self.put_step]
addr_list_tp = addr_list[self.tp_rank % self.put_step::self.put_step]
size_list_tp = size_list[self.tp_rank % self.put_step::self.put_step]
if key_list_tp:
torch.npu.current_stream().synchronize()
self.m_store.put(key_list_tp, addr_list_tp, size_list_tp)
if req_meta.layer_id == self.final_layer_id and req_meta.is_last_chunk:
self.set_finished_request(req_meta.req_id)
self.request_queue.task_done()
class KVCacheStoreLayerRecvingThread(KVTransferThread):
def __init__(self, m_store: Backend, token_database: ChunkedTokenDatabase,
tp_rank: int, ready_event: threading.Event,
get_event: threading.Event):
super().__init__(m_store,
token_database,
tp_rank,
ready_event,
name="KVCacheStoreLayerRecvingThread")
self.get_event = get_event
def add_request( # type: ignore[override]
            self, req_meta: LasyerMultiBlockReqMeta) -> None:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
addr_list = []
size_list = []
key_list = []
for index, key in enumerate(req_meta.keys):
addr, size = self.token_database.prepare_value_layer(
req_meta.starts[index], req_meta.ends[index],
req_meta.block_ids, req_meta.layer_id)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_c = key_list[self.tp_rank %
len(key_list):] + key_list[:self.tp_rank %
len(key_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list):] + size_list[:self.tp_rank %
len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c)
self.request_queue.task_done()
self.get_event.set()

View File

@@ -1,174 +1,33 @@
import threading
from typing import Any, Optional
import torch
import vllm.envs as envs
import zmq
from vllm.attention.backends.abstract import AttentionMetadata
from vllm.config import VllmConfig
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole)
from vllm.forward_context import ForwardContext
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.utils import logger
from vllm.utils.network_utils import make_zmq_socket
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.request import Request
from vllm.v1.serial_utils import MsgpackDecoder, MsgpackEncoder
from vllm.v1.serial_utils import MsgpackEncoder
from vllm_ascend.distributed.mooncake.config_data import (
LoadSpec, MooncakeConnectorMetadata, ReqMeta, RequestTracker)
from vllm_ascend.distributed.mooncake.mooncake_engine import MooncakeEngine
from vllm_ascend.distributed.kvpool.config_data import (
AscendConnectorMetadata, LoadSpec, ReqMeta, RequestTracker)
class MooncakeConnectorV1(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
super().__init__(vllm_config=vllm_config, role=role)
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.use_layerwise = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"use_layerwise", False)
self.kv_caches: dict[str, torch.Tensor] = {}
self._block_size = vllm_config.cache_config.block_size
self.sended_but_unfinished_reqs: set[str] = set()
if role == KVConnectorRole.SCHEDULER:
self.connector_scheduler = MooncakeStoreConnectorV1Scheduler(
vllm_config, self.use_layerwise)
else:
self.connector_worker = MooncakeEngine(
vllm_config,
self.use_layerwise,
)
assert self.connector_worker is not None
if vllm_config.parallel_config.rank == 0:
self.lookup_server = MooncakeLookupServer(
self.connector_worker, vllm_config, self.use_layerwise)
############################################################
# Scheduler Side Methods
############################################################
def get_num_new_matched_tokens(
self, request: "Request",
num_computed_tokens: int) -> tuple[int, bool]:
assert self.connector_scheduler is not None
return self.connector_scheduler.get_num_new_matched_tokens(
request, num_computed_tokens)
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
num_external_tokens: int):
assert self.connector_scheduler is not None
return self.connector_scheduler.update_state_after_alloc(
request, blocks, num_external_tokens)
def build_connector_meta(
self,
scheduler_output: SchedulerOutput,
) -> KVConnectorMetadata:
assert self.connector_scheduler is not None
return self.connector_scheduler.build_connector_meta(scheduler_output)
def request_finished(
self,
request: "Request",
block_ids: list[int],
) -> tuple[bool, Optional[dict[str, Any]]]:
assert self.connector_scheduler is not None
return self.connector_scheduler.request_finished(request, block_ids)
############################################################
# Worker Side Methods
############################################################
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
assert self.connector_worker is not None
self.connector_worker.register_kv_caches(kv_caches)
def start_load_kv(self, forward_context: "ForwardContext",
**kwargs) -> None:
assert self.connector_worker is not None
assert isinstance(self._get_connector_metadata(),
MooncakeConnectorMetadata)
self.connector_worker.start_load_kv(self._get_connector_metadata())
def wait_for_layer_load(self, layer_name: str) -> None:
"""MooncakeStoreConnector does not do layerwise saving."""
if not self.use_layerwise:
return
self.connector_worker.wait_for_layer_load()
def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
attn_metadata: "AttentionMetadata", **kwargs) -> None:
"""MooncakeStoreConnector does not save explicitly."""
if not self.use_layerwise:
return
if self.kv_role == "kv_consumer":
# Don't do save if the role is kv_consumer
return
self.connector_worker.save_kv_layer(self._get_connector_metadata())
def wait_for_save(self):
"""MooncakeStoreConnector does not save explicitly."""
if self.kv_role == "kv_consumer":
# Don't do save if the role is kv_consumer
return
if self.use_layerwise:
self.connector_worker.wait_layer_transfer_finish()
return
self.connector_worker.wait_for_save(self._get_connector_metadata())
def get_finished(self,
finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
"""Get the finished recving and sending requests."""
assert self.connector_worker is not None
meta = self._get_connector_metadata()
done_sending, done_recving = self.connector_worker.get_finished()
sended_and_finished: set[str] = set()
for item in list(self.sended_but_unfinished_reqs):
if item not in meta.unfinished_request_ids:
sended_and_finished.add(item)
self.sended_but_unfinished_reqs.remove(item)
for item in done_sending:
if item in meta.unfinished_request_ids:
self.sended_but_unfinished_reqs.add(item)
else:
sended_and_finished.add(item)
return sended_and_finished, done_recving
def get_zmq_rpc_path_mooncake(
vllm_config: Optional["VllmConfig"] = None, ) -> str:
base_url = envs.VLLM_RPC_BASE_PATH
# Default to 0 if not configured
rpc_port = 0
if vllm_config is not None:
rpc_port = vllm_config.kv_transfer_config.get_from_extra_config(
"mooncake_rpc_port", 0)
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
return f"ipc://{base_url}/mooncake_rpc_port_{rpc_port}"
class MooncakeStoreConnectorV1Scheduler:
class KVPoolScheduler:
def __init__(self, vllm_config: "VllmConfig", use_layerwise):
self.client = MooncakeLookupClient(vllm_config)
self.client = LookupKeyClient(vllm_config)
self.use_layerwise = use_layerwise
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.consumer_is_to_load = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"consumer_is_to_load", False)
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
# request_id -> (vllm cached tokes, mooncake cached tokens)
        # request_id -> (vllm cached tokens, kvpool cached tokens)
self.load_specs: dict[str, LoadSpec] = {}
self._block_size = vllm_config.cache_config.block_size
# request_id -> full_token_ids
@@ -201,14 +60,13 @@ class MooncakeStoreConnectorV1Scheduler:
return 0, False
if self._discard_partial_chunks:
token_block_end = len(request.prompt_token_ids
) // self._block_size * self._block_size
token_ids = torch.tensor(
request.prompt_token_ids[:token_block_end])
token_len = len(request.prompt_token_ids
) // self._block_size * self._block_size
else:
token_ids = torch.tensor(request.prompt_token_ids)
token_len = len(request.prompt_token_ids)
num_external_hit_tokens = self.client.lookup(token_ids)
num_external_hit_tokens = self.client.lookup(token_len,
request.block_hashes)
if num_external_hit_tokens == request.num_tokens:
num_external_hit_tokens -= 1
@@ -216,7 +74,7 @@ class MooncakeStoreConnectorV1Scheduler:
need_to_allocate = num_external_hit_tokens - num_computed_tokens
logger.info(
"Reqid: %s, Total tokens %d, mooncake hit tokens: %d, need to load: %d",
"Reqid: %s, Total tokens %d, kvpool hit tokens: %d, need to load: %d",
request.request_id,
request.num_tokens,
num_external_hit_tokens,
@@ -228,11 +86,11 @@ class MooncakeStoreConnectorV1Scheduler:
self.load_specs[request.request_id] = LoadSpec(
vllm_cached_tokens=num_computed_tokens,
mooncake_cached_tokens=num_external_hit_tokens,
kvpool_cached_tokens=num_external_hit_tokens,
can_load=False,
)
return need_to_allocate, self.load_async
return need_to_allocate, self.load_async and not self.use_layerwise
def update_state_after_alloc(self, request: "Request",
blocks: "KVCacheBlocks",
@@ -261,10 +119,10 @@ class MooncakeStoreConnectorV1Scheduler:
assert (
num_external_tokens > 0 and num_external_tokens
== self.load_specs[request.request_id].mooncake_cached_tokens -
== self.load_specs[request.request_id].kvpool_cached_tokens -
self.load_specs[request.request_id].vllm_cached_tokens
), (f"Mismatch in number of tokens: {num_external_tokens} vs "
f"{self.load_specs[request.request_id].mooncake_cached_tokens} - "
f"{self.load_specs[request.request_id].kvpool_cached_tokens} - "
f"{self.load_specs[request.request_id].vllm_cached_tokens}"
f" for request {request.request_id}")
@@ -289,7 +147,7 @@ class MooncakeStoreConnectorV1Scheduler:
self._unfinished_requests.pop(finished_req_id, None)
self._unfinished_request_ids.discard(finished_req_id)
meta = MooncakeConnectorMetadata(self._unfinished_request_ids)
meta = AscendConnectorMetadata(self._unfinished_request_ids)
for request in scheduler_output.scheduled_new_reqs:
# Right now, we only load KV for new requests
@@ -304,12 +162,15 @@ class MooncakeStoreConnectorV1Scheduler:
self._block_size * self._block_size)
if self._discard_partial_chunks else len(
request.prompt_token_ids))
request_tuple = self._unfinished_requests.get(request.req_id)
request_real = request_tuple[0] # type: ignore[index]
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=load_spec,
skip_save=force_skip_save,
is_last_chunk=len(request_tracker.token_ids)
block_hashes=request_real.block_hashes,
is_last_chunk=request_tracker.token_len
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
@@ -317,33 +178,14 @@ class MooncakeStoreConnectorV1Scheduler:
meta.add_request(req_meta)
cached_reqs = scheduler_output.scheduled_cached_reqs
if isinstance(cached_reqs, list) and not force_skip_save:
for i, req in enumerate(cached_reqs):
request_tracker = self._request_trackers[req.req_id]
request_tracker.update(req.new_token_ids, req.new_block_ids)
last_chunk_tokens_num = ((len(req.prompt_token_ids) //
self._block_size * self._block_size)
if self._discard_partial_chunks else
len(req.prompt_token_ids))
req_meta = ReqMeta.from_request_tracker(
request_tracker,
self._block_size,
load_spec=None,
skip_save=force_skip_save,
is_last_chunk=len(request_tracker.token_ids)
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
meta.add_request(req_meta)
elif not force_skip_save:
if not force_skip_save:
for i, req_id in enumerate(cached_reqs.req_ids):
request_tracker = self._request_trackers[req_id]
num_new_tokens = scheduler_output.num_scheduled_tokens[req_id]
req_tuple = self._unfinished_requests.get(req_id)
if req_tuple:
request = req_tuple[0]
num_current_tokens = len(request_tracker.token_ids)
num_current_tokens = request_tracker.token_len
new_token_ids = request.all_token_ids[
num_current_tokens:num_current_tokens + num_new_tokens]
else:
@@ -355,8 +197,7 @@ class MooncakeStoreConnectorV1Scheduler:
continue
request_tracker.update(new_token_ids, new_block_ids)
                # decode phase: do not save
if len(request_tracker.token_ids) > len(
request.prompt_token_ids):
if request_tracker.token_len > len(request.prompt_token_ids):
continue
last_chunk_tokens_num = ((len(request.prompt_token_ids) //
@@ -368,7 +209,8 @@ class MooncakeStoreConnectorV1Scheduler:
self._block_size,
load_spec=None,
skip_save=force_skip_save,
is_last_chunk=len(request_tracker.token_ids)
block_hashes=request.block_hashes,
is_last_chunk=request_tracker.token_len
>= last_chunk_tokens_num,
discard_partial_chunks=self._discard_partial_chunks,
)
@@ -384,15 +226,14 @@ class MooncakeStoreConnectorV1Scheduler:
load_spec = self.load_specs.pop(request_id, None)
if not load_spec:
continue
num_tokens_to_compute = load_spec.mooncake_cached_tokens
num_tokens_to_compute = load_spec.kvpool_cached_tokens
if (num_tokens_to_compute % self._block_size
!= 0) and (num_tokens_to_compute
== len(request.prompt_token_ids) - 1):
num_tokens_to_compute = num_tokens_to_compute + 1
request_tracker = RequestTracker(
req_id=request_id,
token_ids=request.prompt_token_ids[:num_tokens_to_compute].
copy(),
token_len=num_tokens_to_compute,
allocated_block_ids=block_ids,
num_saved_tokens=0,
)
@@ -404,6 +245,7 @@ class MooncakeStoreConnectorV1Scheduler:
self._block_size,
load_spec=load_spec,
skip_save=None,
block_hashes=request.block_hashes,
discard_partial_chunks=self._discard_partial_chunks,
)
if req_meta is not None:
@@ -431,12 +273,12 @@ class MooncakeStoreConnectorV1Scheduler:
return delay_free_blocks, None
class MooncakeLookupClient:
class LookupKeyClient:
def __init__(self, vllm_config: "VllmConfig"):
self.encoder = MsgpackEncoder()
self.ctx = zmq.Context() # type: ignore[attr-defined]
socket_path = get_zmq_rpc_path_mooncake(vllm_config)
socket_path = get_zmq_rpc_path_lookup(vllm_config)
self.socket = make_zmq_socket(
self.ctx,
socket_path,
@@ -444,9 +286,12 @@ class MooncakeLookupClient:
bind=False,
)
def lookup(self, token_ids: torch.Tensor) -> int:
request = self.encoder.encode(token_ids)
self.socket.send_multipart(request, copy=False)
def lookup(self, token_len: int, block_hashes: list[BlockHash]) -> int:
hash_strs = [h.hex() for h in block_hashes]
hash_frames = self.encoder.encode(hash_strs)
token_len_bytes = token_len.to_bytes(4, byteorder="big")
all_frames = [token_len_bytes] + list(hash_frames)
self.socket.send_multipart(all_frames, copy=False)
resp = self.socket.recv()
result = int.from_bytes(resp, "big")
return result
@@ -455,39 +300,19 @@ class MooncakeLookupClient:
self.socket.close(linger=0)
class MooncakeLookupServer:
def __init__(
self,
mooncake_engine: MooncakeEngine,
vllm_config: "VllmConfig",
use_layerwise: bool,
):
self.decoder = MsgpackDecoder(torch.Tensor)
self.ctx = zmq.Context() # type: ignore[attr-defined]
socket_path = get_zmq_rpc_path_mooncake(vllm_config)
self.socket = make_zmq_socket(
self.ctx,
socket_path,
zmq.REP, # type: ignore[attr-defined]
bind=True,
)
self.mooncake_engine = mooncake_engine
self.running = True
def process_request():
while self.running:
frames = self.socket.recv_multipart(copy=False)
token_ids = self.decoder.decode(frames)
result = self.mooncake_engine.lookup_scheduler(
token_ids, use_layerwise)
response = result.to_bytes(4, "big")
self.socket.send(response)
self.thread = threading.Thread(target=process_request, daemon=True)
self.thread.start()
def close(self):
self.socket.close(linger=0)
# TODO: close the thread!
def get_zmq_rpc_path_lookup(
vllm_config: Optional["VllmConfig"] = None, ) -> str:
base_url = envs.VLLM_RPC_BASE_PATH
# Default to 0 if not configured
rpc_port = 0
if vllm_config is not None:
extra_config = vllm_config.kv_transfer_config.kv_connector_extra_config
if "lookup_rpc_port" in extra_config:
rpc_port = extra_config["lookup_rpc_port"]
elif "mooncake_rpc_port" in extra_config:
rpc_port = extra_config["mooncake_rpc_port"]
logger.warning(
"It is recommended to use the lookup_rpc_port, as the mooncake_rpc_port will be removed in the future."
)
logger.debug("Base URL: %s, RPC Port: %s", base_url, rpc_port)
return f"ipc://{base_url}/lookup_rpc_port_{rpc_port}"

View File

@@ -1,25 +1,33 @@
# Standard
import math
import threading
import time
from typing import Generator, List, Optional, Union
from typing import Dict, Generator, Optional, Type
# Third Party
import torch
from vllm.config import VllmConfig
from vllm.utils import logger
from vllm.utils.torch_utils import get_kv_cache_torch_dtype
from vllm.v1.core.kv_cache_utils import BlockHash
from vllm_ascend.distributed.mooncake.config_data import (
ChunkedTokenDatabase, LasyerMultiBlockReqMeta, MooncakeConnectorMetadata,
MooncakeEngineMetadata)
from vllm_ascend.distributed.mooncake.kv_transfer import (
from vllm_ascend.distributed.kvpool.backend.backend import Backend
from vllm_ascend.distributed.kvpool.backend.memcache_backend import \
MemcacheBackend
from vllm_ascend.distributed.kvpool.backend.mooncake_backend import \
MooncakeBackend
from vllm_ascend.distributed.kvpool.config_data import (
AscendConnectorMetadata, ChunkedTokenDatabase, KeyMetadata,
LasyerMultiBlockReqMeta)
from vllm_ascend.distributed.kvpool.kv_transfer import (
KVCacheStoreLayerRecvingThread, KVCacheStoreLayerSendingThread,
KVCacheStoreRecvingThread, KVCacheStoreSendingThread, KVTransferThread)
from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore
backend_map: Dict[str, Type[Backend]] = {
"mooncake": MooncakeBackend,
"memcache": MemcacheBackend,
}
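The registry above maps the `backend` string from `kv_connector_extra_config` to a `Backend` subclass. A self-contained sketch of the same pattern, with simplified constructors (the real classes take a `ParallelConfig`) and a hypothetical `make_backend` helper:

```python
from typing import Dict, Type

class Backend:                    # stand-in for the real ABC
    pass

class MooncakeBackend(Backend):
    pass

class MemcacheBackend(Backend):
    pass

backend_map: Dict[str, Type[Backend]] = {
    "mooncake": MooncakeBackend,
    "memcache": MemcacheBackend,
}

def make_backend(extra_config: dict) -> Backend:
    # Default to mooncake, match case-insensitively, fail loudly otherwise.
    name = str(extra_config.get("backend", "mooncake")).lower()
    cls = backend_map.get(name)
    if cls is None:
        raise ValueError(f"unknown kvpool backend: {name}")
    return cls()

assert isinstance(make_backend({}), MooncakeBackend)
assert isinstance(make_backend({"backend": "Memcache"}), MemcacheBackend)
```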
class MooncakeEngine:
class KVPoolWorker:
# The main class for the cache engine.
def __init__(
@@ -29,6 +37,7 @@ class MooncakeEngine:
):
model_config = vllm_config.model_config
parallel_config = vllm_config.parallel_config
self.dp_rank = parallel_config.data_parallel_rank
self.use_mla = False
if (hasattr(model_config, "use_mla")
and isinstance(model_config.use_mla, bool)
@@ -40,37 +49,37 @@ class MooncakeEngine:
self.kv_role = vllm_config.kv_transfer_config.kv_role
self.load_async = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"load_async", False)
self.register_buffer = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"register_buffer", False)
self.backend = vllm_config.kv_transfer_config.kv_connector_extra_config.get(
"backend", "mooncake")
self.block_size = vllm_config.cache_config.block_size
self.current_layer = 0
# self.use_mla = first_kv_cache_tuple[0].size(
# -1) != first_kv_cache_tuple[1].size(-1)
self.num_layers = model_config.get_num_layers(parallel_config)
self.block_size = vllm_config.cache_config.block_size
num_kv_head = model_config.get_num_kv_heads(parallel_config)
head_size = model_config.get_head_size()
kv_dtype = get_kv_cache_torch_dtype(
vllm_config.cache_config.cache_dtype, model_config.dtype)
self.hidden_dim_size = num_kv_head * head_size
if self.use_mla:
kv_shape = (self.num_layers, 1, self.block_size, 1, head_size)
self.num_kv_head = 1
else:
kv_shape = (self.num_layers, 2, self.block_size, num_kv_head,
head_size)
self.metadata = MooncakeEngineMetadata(
self.num_kv_head = model_config.get_total_num_kv_heads()
if self.num_kv_head < self.tp_size:
self.put_step = self.tp_size // self.num_kv_head
self.head_or_tp_rank = self.tp_rank // self.put_step
else:
self.head_or_tp_rank = self.tp_rank
self.put_step = 1
self.metadata = KeyMetadata(
model_config.model,
parallel_config.world_size,
parallel_config.rank,
kv_dtype,
kv_shape,
self.block_size,
self.use_mla,
self.head_or_tp_rank,
)
self.token_database = ChunkedTokenDatabase(self.metadata)
self.token_database = ChunkedTokenDatabase(self.metadata,
self.block_size,
self.use_mla)
self.m_store = Mooncakestore(parallel_config)
real_backend = backend_map.get(self.backend.lower())
self.m_store = real_backend( # type: ignore[misc]
parallel_config)
self.kv_send_thread: Optional[KVTransferThread] = None
self.kv_recv_thread: Optional[KVTransferThread] = None
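A worked example (toy sizes, not from the diff) of the head/TP-rank mapping above: when there are fewer KV heads than TP ranks, `put_step` ranks replicate one KV head, so they can share one key identity.

```python
def head_mapping(tp_rank: int, tp_size: int, num_kv_head: int) -> tuple[int, int]:
    # Returns (head_or_tp_rank, put_step), mirroring the logic above.
    if num_kv_head < tp_size:
        put_step = tp_size // num_kv_head   # ranks that replicate one KV head
        return tp_rank // put_step, put_step
    return tp_rank, 1

# tp_size=8, num_kv_head=2: ranks 0-3 map to head 0, ranks 4-7 to head 1.
assert head_mapping(0, 8, 2) == (0, 4)
assert head_mapping(5, 8, 2) == (1, 4)
# tp_size=2, num_kv_head=8: every rank keeps its own identity.
assert head_mapping(1, 2, 8) == (1, 1)
```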
@@ -108,94 +117,83 @@ class MooncakeEngine:
self.kv_caches = kv_caches
self.kv_caches_base_addr = []
ptrs = []
lengths = []
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
if self.use_mla:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
if self.register_buffer:
region_len = self.num_blocks * self.block_len[i % 2]
self._register(base_addr, region_len)
region_len = self.num_blocks * self.block_len[i % 2]
ptrs.append(base_addr)
lengths.append(region_len)
else:
cache_list = [cache_or_caches
] if self.use_mla else cache_or_caches
for cache in cache_list:
base_addr = cache.data_ptr()
self.kv_caches_base_addr.append(base_addr)
if self.register_buffer:
region_len = self.num_blocks * self.block_len[0]
self._register(base_addr, region_len)
region_len = self.num_blocks * self.block_len[0]
ptrs.append(base_addr)
lengths.append(region_len)
self.m_store.register_buffer(ptrs, lengths)
self.token_database.set_kv_caches_base_addr(self.kv_caches_base_addr)
self.token_database.set_block_len(self.block_len)
if self.use_layerwise:
self.get_event = threading.Event()
if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreLayerSendingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database,
self.block_len, self.block_size, ready_event_sending,
self.num_layers)
self.m_store, self.token_database, self.tp_rank,
self.put_step, ready_event_sending, self.num_layers)
self.kv_send_thread.start()
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreLayerRecvingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database, self.block_len,
self.block_size, ready_event, self.get_event)
self.m_store, self.token_database, self.tp_rank, ready_event,
self.get_event)
self.kv_recv_thread.start()
ready_event.wait()
else:
if self.kv_role in ['kv_producer', 'kv_both']:
ready_event_sending = threading.Event()
self.kv_send_thread = KVCacheStoreSendingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database,
self.block_len, self.block_size, ready_event_sending)
self.m_store, self.token_database, self.tp_rank,
self.put_step, ready_event_sending)
self.kv_send_thread.start()
if self.load_async:
ready_event = threading.Event()
self.kv_recv_thread = KVCacheStoreRecvingThread(
self.tp_rank, self.tp_size, self.m_store,
self.kv_caches_base_addr, self.token_database,
self.block_len, self.block_size, ready_event)
self.m_store, self.token_database, self.tp_rank,
ready_event)
self.kv_recv_thread.start()
ready_event.wait()
def _register(self, ptr, length):
logger.debug(
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
try:
self.m_store.register_buffer(ptr, length)
except Exception as e:
raise RuntimeError(
f"Mooncake memory registration failed. Error is: {e}")
def start_load_kv(self, metadata: MooncakeConnectorMetadata):
def start_load_kv(self, metadata: AscendConnectorMetadata):
self.current_layer = 0
self.layerwise_retrievers = []
for request in metadata.requests:
load_spec = request.load_spec
if load_spec is None or not load_spec.can_load: #load =0
continue
tokens = request.token_ids
token_len = request.token_len_chunk
req_id = request.req_id
if (load_spec.mooncake_cached_tokens % self.block_size
!= 0) and (load_spec.mooncake_cached_tokens
== tokens.shape[0] - 1):
tokens = tokens[:request.load_spec.mooncake_cached_tokens + 1]
if (load_spec.kvpool_cached_tokens % self.block_size
!= 0) and (load_spec.kvpool_cached_tokens
== token_len - 1):
token_len = request.load_spec.kvpool_cached_tokens + 1
else:
tokens = tokens[:request.load_spec.mooncake_cached_tokens]
masked_token_count = (request.load_spec.vllm_cached_tokens //
self.block_size * self.block_size)
token_mask = torch.ones_like(tokens, dtype=torch.bool)
token_mask[:masked_token_count] = False
token_len = request.load_spec.kvpool_cached_tokens
mask_num = (request.load_spec.vllm_cached_tokens //
self.block_size * self.block_size)
if self.use_layerwise:
layerwise_retriever = self.retrieve_layer(
req_id,
tokens,
token_len,
request.block_ids,
token_mask,
request.block_hashes,
mask_num,
)
next(layerwise_retriever) # first layer load
self.layerwise_retrievers.append(layerwise_retriever)
@@ -203,102 +201,84 @@ class MooncakeEngine:
if self.load_async:
self.kv_recv_thread.add_request( # type: ignore[union-attr]
req_id,
tokens,
token_len,
request.block_ids,
token_mask,
request.block_hashes,
mask_num,
)
else:
if self.m_store.config.use_ascend_direct:
addr_list = []
size_list = []
key_list = []
blockIds = []
for start, end, key in self.token_database.process_tokens(
tokens, token_mask):
addr, size, block_id = self.prepare_value(
start, end, request.block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
blockIds.append(block_id)
self.m_store.get_batch(key_list, addr_list, size_list,
blockIds)
else:
for start, end, key in self.token_database.process_tokens(
tokens, token_mask):
addr, size, _ = self.prepare_value(
start, end, request.block_ids)
self.m_store.get(key, addr, size)
def prepare_value(self, start: int, end: int, block_ids: list[int]):
addr_list = []
size_list = []
block_id = block_ids[start // self.block_size]
for index, base_addr in enumerate(self.kv_caches_base_addr):
block_len = (self.block_len[index % 2]
if self.use_mla else self.block_len[0])
addr = base_addr + block_id * block_len
length = int(block_len / self.block_size * (end - start))
addr_list.append(addr)
size_list.append(length)
return addr_list, size_list, block_id
addr_list = []
size_list = []
key_list = []
for start, end, key in self.token_database.process_tokens(
token_len, request.block_hashes, mask_num):
addr, size, _ = self.token_database.prepare_value(
start, end, request.block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
key_list_c = key_list[self.tp_rank % len(
key_list):] + key_list[:self.tp_rank % len(key_list)]
addr_list_c = addr_list[self.tp_rank %
len(addr_list
):] + addr_list[:self.tp_rank %
len(addr_list)]
size_list_c = size_list[self.tp_rank %
len(size_list
):] + size_list[:self.tp_rank %
len(size_list)]
self.m_store.get(key_list_c, addr_list_c, size_list_c)
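The slicing above is a rotation by `tp_rank`: each TP rank starts its batched get at a different offset, so the ranks do not all fetch the same key first. A compact sketch of the same rotation:

```python
def rotate(items: list, tp_rank: int) -> list:
    # Start the list at tp_rank's offset, wrapping around the end.
    k = tp_rank % len(items)
    return items[k:] + items[:k]

keys = ["k0", "k1", "k2", "k3"]
assert rotate(keys, 0) == ["k0", "k1", "k2", "k3"]
assert rotate(keys, 1) == ["k1", "k2", "k3", "k0"]
assert rotate(keys, 5) == ["k1", "k2", "k3", "k0"]   # wraps modulo len(keys)
```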
def wait_for_layer_load(self) -> None:
"""MooncakeConnector does not do layerwise saving."""
for layerwise_retriever in self.layerwise_retrievers:
ret_token_mask = next(layerwise_retriever)
if self.current_layer == self.num_layers - 1:
assert ret_token_mask is not None
num_retrieved_tokens = ret_token_mask.sum().item()
logger.info(f"Retrieved {num_retrieved_tokens} tokens")
logger.debug(f"Retrieved {num_retrieved_tokens} tokens")
def save_kv_layer(self,
connector_metadata: MooncakeConnectorMetadata) -> None:
"""MooncakeConnector does not save explicitly."""
connector_metadata: AscendConnectorMetadata) -> None:
if self.current_layer == 0:
self.layerwise_storers = []
for request in connector_metadata.requests:
save_spec = request.save_spec
if save_spec is None or not save_spec.can_save:
can_save = request.can_save
if can_save is None or not can_save:
continue
token_ids = request.token_ids
token_len = request.token_len_chunk
req_id = request.req_id
assert isinstance(token_ids, torch.Tensor)
assert token_ids.is_cpu
# TODO: decide whether the save thread can be removed
# no lookup, skip mask
skip_leading_tokens = max(
self.lookup(token_ids, self.use_layerwise),
save_spec.skip_leading_tokens,
)
if skip_leading_tokens == len(token_ids):
skip_leading_tokens = self.lookup(token_len,
request.block_hashes,
self.use_layerwise)
if skip_leading_tokens == token_len:
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request
skip_leading_tokens = (skip_leading_tokens // self.block_size *
self.block_size)
mask_num = (skip_leading_tokens // self.block_size *
self.block_size)
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
store_mask[:skip_leading_tokens] = False
logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
len(token_ids) - skip_leading_tokens,
len(token_ids),
token_len - skip_leading_tokens,
token_len,
skip_leading_tokens,
request.req_id,
)
layerwise_storer = self.store_layer(
req_id,
token_ids,
mask=store_mask,
token_len,
block_hashes=request.block_hashes,
mask_num=mask_num,
block_ids=request.block_ids,
is_last_chunk=request.is_last_chunk,
)
self.layerwise_storers.append(layerwise_storer)
for layerwise_storer in self.layerwise_storers:
@@ -306,59 +286,53 @@ class MooncakeEngine:
next(layerwise_storer)
except Exception:
raise
self.current_layer = self.current_layer + 1
self.current_layer = self.current_layer + 1
def wait_for_save(self, connector_metadata: MooncakeConnectorMetadata):
"""MooncakeConnector does not save explicitly."""
def wait_for_save(self, connector_metadata: AscendConnectorMetadata):
for request in connector_metadata.requests:
save_spec = request.save_spec
if save_spec is None or not save_spec.can_save:
can_save = request.can_save
if can_save is None or not can_save:
continue
token_ids = request.token_ids
token_len = request.token_len_chunk
req_id = request.req_id
assert isinstance(token_ids, torch.Tensor)
assert token_ids.is_cpu
skip_leading_tokens = max(
self.lookup(token_ids, self.use_layerwise),
save_spec.skip_leading_tokens,
)
if skip_leading_tokens == len(token_ids):
skip_leading_tokens = self.lookup(token_len, request.block_hashes,
self.use_layerwise)
if skip_leading_tokens == token_len:
if request.is_last_chunk:
self.kv_send_thread.set_finished_request( # type: ignore[union-attr]
req_id)
continue # skip this request
skip_leading_tokens = (skip_leading_tokens // self.block_size *
self.block_size)
store_mask = torch.ones_like(token_ids, dtype=torch.bool)
store_mask[:skip_leading_tokens] = False
mask_num = (skip_leading_tokens // self.block_size *
self.block_size)
logger.info(
"Storing KV cache for %d out of %d tokens "
"(skip_leading_tokens=%d) for request %s",
len(token_ids) - skip_leading_tokens,
len(token_ids),
token_len - skip_leading_tokens,
token_len,
skip_leading_tokens,
request.req_id,
)
self.kv_send_thread.add_request( # type: ignore[union-attr]
req_id,
token_ids,
token_len,
request.block_ids,
store_mask,
request.block_hashes,
mask_num,
request.is_last_chunk,
)
def retrieve_layer(
self,
req_id: str,
tokens: torch.Tensor,
token_len: int,
block_ids: list[int],
mask: Optional[torch.Tensor] = None,
block_hashes: list[BlockHash],
mask_num: int = 0,
) -> Generator[Optional[torch.Tensor], None, None]:
"""
Retrieve the KV cache in a layerwise manner.
@@ -376,20 +350,16 @@ class MooncakeEngine:
be the boolean mask indicating which tokens are retrieved and will
only be returned in the last iteration.
"""
num_required_tokens = token_len - mask_num
if mask is not None:
num_required_tokens = torch.sum(mask).item()
else:
num_required_tokens = len(tokens)
ret_mask = torch.zeros_like(tokens, dtype=torch.bool, device="cpu")
ret_mask = torch.zeros(token_len, dtype=torch.bool, device="cpu")
starts = []
ends = []
keys = []
first_flag = True
for start, end, key in self.token_database.process_tokens(
tokens, mask):
token_len, block_hashes, mask_num):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
@@ -421,16 +391,18 @@ class MooncakeEngine:
retrieved_tokens = torch.sum(ret_mask)
logger.debug(f"Retrieved {retrieved_tokens} "
f"out of {num_required_tokens} "
f"out of total {len(tokens)} tokens")
f"out of total {token_len} tokens")
yield ret_mask
def store_layer(
self,
req_id: str,
tokens: torch.Tensor,
token_len: int,
block_ids: list[int],
mask: Optional[torch.Tensor] = None,
block_hashes: list[BlockHash],
is_last_chunk: bool,
mask_num: int = 0,
) -> Generator[None, None, None]:
"""
Store the KV cache in a layerwise manner.
@@ -452,17 +424,13 @@ class MooncakeEngine:
storage backends. In the last iteration, it puts the memory objects
of the last layer to the storage backends.
"""
if mask is not None:
num_stored_tokens = torch.sum(mask).item()
else:
num_stored_tokens = len(tokens)
num_stored_tokens = token_len - mask_num
starts = []
ends = []
keys = []
for start, end, key in self.token_database.process_tokens(
tokens, mask):
token_len, block_hashes, mask_num):
keys_multi_layer = key.split_layers(self.num_layers)
starts.append(start)
ends.append(end)
@@ -473,7 +441,7 @@ class MooncakeEngine:
for layer_id, keys_multi_chunk in enumerate(keys):
req_meta = LasyerMultiBlockReqMeta(req_id, keys_multi_chunk,
starts, ends, block_ids,
layer_id)
layer_id, is_last_chunk)
self.kv_send_thread.add_request( # type: ignore[union-attr, call-arg]
req_meta) # type: ignore[union-attr, call-arg, arg-type]
yield
@@ -481,7 +449,7 @@ class MooncakeEngine:
for layer_id in range(self.num_layers):
yield
logger.debug(
f"Stored {num_stored_tokens} out of total {len(tokens)} tokens")
f"Stored {num_stored_tokens} out of total {token_len} tokens")
def get_finished(self) -> tuple[set[str], set[str]]:
done_sending = (
@@ -500,13 +468,10 @@ class MooncakeEngine:
self.tp_rank)
return done_sending, done_recving
def wait_layer_transfer_finish(self):
time.sleep(10)
pass
def lookup(
self,
tokens: Union[torch.Tensor, List[int]],
token_len: int,
block_hashes: list[BlockHash],
use_layerwise: bool,
) -> int:
"""
@@ -517,34 +482,24 @@ class MooncakeEngine:
end = 0
keys = []
try:
if use_layerwise:
for start, end, key in self.token_database.process_tokens(
tokens):
starts = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes):
if use_layerwise:
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
keys.append(item.to_string())
# batched is_exist check
ress = self.m_store.batch_exists(keys)
res = 1
for value in ress:
if value != 1:
res = 0
break
if res == 1:
continue
else:
return start
else:
starts = []
for start, end, key in self.token_database.process_tokens(
tokens):
else:
keys.append(key.to_string())
starts.append(start)
res = self.m_store.batch_exists(
keys) # type: ignore[assignment]
for index, value in enumerate(res): # type: ignore[arg-type]
if value != 1:
return starts[index]
starts.append(start)
res = self.m_store.exists(keys) # type: ignore[assignment]
if use_layerwise:
res = self.check_all_layers_exists(res, self.num_layers)
for index, value in enumerate(res): # type: ignore[arg-type]
if value != 1:
return starts[index]
# all tokens were found; return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
@@ -553,7 +508,8 @@ class MooncakeEngine:
def lookup_scheduler(
self,
tokens: Union[torch.Tensor, List[int]],
token_len: int,
block_hashes: list[BlockHash],
use_layerwise: bool,
) -> int:
"""
@@ -564,59 +520,59 @@ class MooncakeEngine:
end = 0
keys = []
try:
if use_layerwise:
for start, end, key in self.token_database.process_tokens(
tokens):
starts = []
for start, end, key in self.token_database.process_tokens(
token_len, block_hashes):
if use_layerwise:
keys_multi_layer = key.split_layers(self.num_layers)
for item in keys_multi_layer:
keys.append(item.to_string())
# batched is_exist check
ress = self.m_store.batch_exists(keys)
res = 1
for value in ress:
if value != 1:
res = 0
break
if res == 1:
continue
else:
return start
else:
starts = []
for start, end, key in self.token_database.process_tokens(
tokens):
else:
keys.append(key.to_string())
starts.append(start)
multi_tp_keys = keys[:]
for i in range(1, self.tp_size):
for item in keys:
new_str = item.replace( # type: ignore[attr-defined]
"@0", f"@{i}", 1)
multi_tp_keys.append(new_str)
res = self.m_store.batch_exists(
multi_tp_keys) # type: ignore[assignment]
num_block = len(keys)
multi_tp_values = [
res[i * num_block:(i + 1) *
num_block] # type: ignore[index]
for i in range(self.tp_size)
]
index = self.find_min_first_non_one_index(multi_tp_values)
if index != -1:
return starts[index]
starts.append(start)
multi_tp_keys = keys[:]
for i in range(1, min(self.tp_size, self.num_kv_head)):
for item in keys:
new_str = item.replace( # type: ignore[attr-defined]
"@head_or_tp_rank:0", f"@head_or_tp_rank:{i}", 1)
multi_tp_keys.append(new_str)
res = self.m_store.exists(
multi_tp_keys) # type: ignore[assignment]
num_block = len(keys)
if use_layerwise:
res = self.check_all_layers_exists(res, self.num_layers)
num_block = len(keys) // self.num_layers
multi_tp_values = [
res[i * num_block:(i + 1) * num_block] # type: ignore[index]
for i in range(min(self.tp_size, self.num_kv_head))
]
index = self.find_min_first_non_one_index(multi_tp_values)
if index != -1:
return starts[index]
# all tokens were found; return the maximal end
except Exception as e:
logger.error(f"Remote connection failed in contains: {e}")
return start
return end
def check_all_layers_exists(self, res: list[int],
num_layers: int) -> list[int]:
total_chunks = len(res) // num_layers
result = []
for chunk_idx in range(total_chunks):
start = chunk_idx * num_layers
end = start + num_layers
chunk = res[start:end]
result.append(1 if all(x == 1 for x in chunk) else 0)
return result
def find_min_first_non_one_index(self, arr):
try:
return min(idx for row in arr for idx, val in enumerate(row)
if val != 1)
except ValueError:
return -1
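A toy run (values chosen for illustration) of the two helpers above: per-layer existence bits are collapsed into per-chunk bits, and the earliest chunk missing on any head/TP row bounds the usable prefix.

```python
# 2 chunks x 2 layers; chunk 1 is missing one layer.
res = [1, 1, 1, 0]
num_layers = 2
per_chunk = [
    1 if all(v == 1 for v in res[i * num_layers:(i + 1) * num_layers]) else 0
    for i in range(len(res) // num_layers)
]
assert per_chunk == [1, 0]

rows = [per_chunk, [1, 1]]   # e.g. two head/TP rows
first_missing = min(
    (idx for row in rows for idx, v in enumerate(row) if v != 1), default=-1)
assert first_missing == 1    # only tokens before chunk 1 can be loaded
```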
def close(self) -> None:
"""Close the cache engine and free all the resources"""
self.m_store.close()

View File

@@ -28,6 +28,7 @@ from vllm.forward_context import ForwardContext
from vllm.utils import logger
from vllm.v1.core.kv_cache_manager import KVCacheBlocks
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import Request, RequestStatus
import vllm_ascend.envs as envs_ascend
@@ -100,7 +101,10 @@ class LLMDataDistCMgrConnectorMetadata(KVConnectorMetadata):
class LLMDataDistCMgrConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
def __init__(self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None):
assert vllm_config.kv_transfer_config is not None
self.engine_id = vllm_config.kv_transfer_config.engine_id
if role == KVConnectorRole.SCHEDULER:

View File

@@ -1,534 +0,0 @@
import array
import hashlib
import json
import os
import re
from dataclasses import dataclass
from typing import Iterable, List, Optional, Tuple, Union
import torch
from vllm.distributed.kv_transfer.kv_connector.v1.base import \
KVConnectorMetadata
from vllm.utils import logger
from vllm.utils.math_utils import cdiv
from vllm.v1.core.sched.output import NewRequestData
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
@dataclass
class MooncakeEngineMetadata:
"""name of the LLM model"""
model_name: str
""" world size when running under a distributed setting """
world_size: int
""" worker id when running under a distributed setting """
worker_id: int
""" the format of kv tensors """
kv_dtype: torch.dtype
""" the shape of kv tensors """
""" (num_layer, 2, metadata.block_size, num_kv_head, head_size) """
kv_shape: tuple[int, int, int, int, int]
block_size: int = 128
""" whether use MLA"""
use_mla: bool = False
@dataclass(order=True)
class MooncakeEngineKey:
model_name: str
world_size: int
worker_id: int
chunk_hash: str
def __hash__(self):
return hash((
self.model_name,
self.world_size,
self.worker_id,
self.chunk_hash,
))
def to_string(self):
return (f"{self.model_name}@{self.world_size}"
f"@{self.worker_id}@{self.chunk_hash}")
def split_layers(self, num_layers: int) -> List["LayerMooncakeEngineKey"]:
"""Split the key into multiple keys for each layer"""
keys = []
for layer_id in range(num_layers):
keys.append(
LayerMooncakeEngineKey(
self.model_name,
self.world_size,
self.worker_id,
self.chunk_hash,
layer_id,
))
return keys
def to_dict(self):
# Note(Kuntai): this is used for serializing CacheEngineKey via msgpack.
return {
"__type__": "CacheEngineKey",
"model_name": self.model_name,
"world_size": self.world_size,
"worker_id": self.worker_id,
"chunk_hash": self.chunk_hash,
}
@staticmethod
def from_dict(d):
return MooncakeEngineKey(
model_name=d["model_name"],
world_size=d["world_size"],
worker_id=d["worker_id"],
chunk_hash=d["chunk_hash"],
)
@dataclass(order=True)
class LayerMooncakeEngineKey(MooncakeEngineKey):
"""A key for the layer cache engine"""
layer_id: int
def __hash__(self):
return hash((
self.model_name,
self.world_size,
self.worker_id,
self.chunk_hash,
self.layer_id,
))
def to_string(self):
return (f"{self.model_name}@{self.world_size}"
f"@{self.worker_id}@{self.chunk_hash}@{self.layer_id}")
class ChunkedTokenDatabase():
def __init__(
self,
metadata: MooncakeEngineMetadata,
):
self.metadata = metadata
def _make_key_by_hash(self,
chunk_hash: str,
layer_id: Optional[int] = None):
assert self.metadata is not None
return MooncakeEngineKey(
self.metadata.model_name,
self.metadata.world_size,
self.metadata.worker_id,
chunk_hash,
)
def _hash(
self,
tokens: Union[torch.Tensor, List[int]],
prefix_hash: str,
) -> str:
# TODO: change it to a more efficient hash function
if isinstance(tokens, torch.Tensor):
tokens_bytes = tokens.cpu().to(torch.uint32).numpy().tobytes()
elif isinstance(tokens, list):
tokens_bytes = array.array("I", tokens).tobytes()
return hashlib.sha256(prefix_hash.encode("ascii") +
tokens_bytes).hexdigest()
def _chunk_tokens(
self,
tokens: Union[torch.Tensor, List[int]],
) -> Iterable[Union[torch.Tensor, List[int]]]:
"""
Chunk the tokens into chunks of size self.metadata.block_size.
:param tokens: the input tokens, with shape [seq_len]
device: the target device after chunking
:return: a generator of chunks of tokens, each with
shape [metadata.block_size]
"""
for i in range(0, len(tokens), self.metadata.block_size):
yield tokens[i:i + self.metadata.block_size]
def _prefix_hash(
self,
token_chunks: Iterable[Union[torch.Tensor, List[int]]],
) -> Iterable[str]:
prefix_hash = ''
for token_chunk in token_chunks:
prefix_hash = self._hash(token_chunk, prefix_hash)
yield prefix_hash
def process_tokens(
self,
tokens: Union[torch.Tensor, List[int]],
mask: Optional[torch.Tensor] = None,
) -> Iterable[Tuple[int, int, MooncakeEngineKey]]:
"""Process the tokens and return the corresponding cache engine keys.
:param Union[torch.Tensor, List[int]] tokens: The tokens to process.
:param Optional[torch.Tensor] mask: The mask for the tokens. Should
have the same length as tokens. The mask should ALWAYS look like
FFFFFTTTTTTT, where True means the token needs to be matched,
and the Falses ALWAYS sit at the PREFIX of the tensor.
:param bool make_key: Whether to make the cache engine key or not.
If False, the hash value will be returned instead.
:returns: An iterable of tuples with three elements. The first element
is the start index of the tokens for the key. The second element
is the end index of the tokens for the key. The third element is
the cache engine key (or hash) for the tokens.
:raises: ValueError if the number of Falses in the mask is not a
multiple of the chunk size.
"""
if mask is not None:
num_falses = mask.numel() - mask.long().sum().item()
else:
num_falses = 0
if num_falses % self.metadata.block_size != 0:
raise ValueError(
"The number of Falses in the mask is not a multiple of the chunk size."
)
total_len = len(tokens)
token_chunks = self._chunk_tokens(tokens)
prefix_hashes = self._prefix_hash(token_chunks)
start_idx = 0
for chunk_id, hash_val in enumerate(prefix_hashes):
start_idx = chunk_id * self.metadata.block_size
end_idx = min(start_idx + self.metadata.block_size, total_len)
if start_idx < num_falses:
continue
else:
yield start_idx, end_idx, self._make_key_by_hash(hash_val)
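A condensed, runnable sketch of the chunked prefix hashing in the (removed) `ChunkedTokenDatabase` above: each chunk's digest folds in the previous digest, so two prompts sharing a prefix produce identical keys for the shared chunks.

```python
import array
import hashlib

def prefix_hashes(tokens: list[int], block_size: int):
    # Rolling SHA-256 over block_size-sized chunks, seeded by the prior digest.
    prefix = ""
    for i in range(0, len(tokens), block_size):
        chunk = array.array("I", tokens[i:i + block_size]).tobytes()
        prefix = hashlib.sha256(prefix.encode("ascii") + chunk).hexdigest()
        yield i, min(i + block_size, len(tokens)), prefix

a = list(prefix_hashes(list(range(256)), 128))
b = list(prefix_hashes(list(range(128)) + [999] * 128, 128))
assert a[0][2] == b[0][2]     # shared first chunk -> same key hash
assert a[1][2] != b[1][2]     # divergent second chunk -> different key hash
```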
@dataclass
class LoadSpec:
# Number of tokens cached in vLLM
vllm_cached_tokens: int
# Number of tokens that are cached in mooncake
mooncake_cached_tokens: int
# Whether the scheduler allow us to load the tokens
can_load: bool
@dataclass
class SaveSpec:
# Skip already saved tokens
skip_leading_tokens: int
# Whether the scheduler allow us to save the tokens
can_save: bool
@dataclass
class RequestTracker:
# Request id
req_id: str
# The token ids that have been scheduled so far
token_ids: list[int]
# The block ids that have been allocated so far
# NOTE: allocated blocks could be more than the number of tokens
# FIXME: need to check whether the block ids will be changed after
# preemption
allocated_block_ids: list[int]
# The number of tokens that have been saved
num_saved_tokens: int = 0
@staticmethod
def from_new_request(
new_request: "NewRequestData",
num_tokens_to_compute: int,
) -> "RequestTracker":
"""Create the request tracker from a new request.
Args:
new_request (NewRequestData): the new request data.
num_tokens_to_compute (int): the number of tokens that will
be 'computed', including the `num_computed_tokens` (vLLM's
local cache hit) and new tokens that will be scheduled.
"""
# vLLM 0.9.0 update: request.block_ids changed from list[int] to
# list[list[int]]
# Need to check the type of request.block_ids
unfolded_block_ids = []
if not isinstance(new_request.block_ids[0], list):
unfolded_block_ids = new_request.block_ids.copy()
else:
unfolded_block_ids = new_request.block_ids[0].copy()
return RequestTracker(
req_id=new_request.req_id,
token_ids=new_request.prompt_token_ids[:num_tokens_to_compute].
copy(),
allocated_block_ids=unfolded_block_ids,
num_saved_tokens=0,
)
def update(
self,
new_token_ids: list[int],
new_block_ids: Union[tuple[list[int], ...], list[int]],
) -> None:
"""Update the request tracker when a running request is
scheduled again
"""
self.token_ids.extend(new_token_ids)
if len(new_block_ids) == 0:
new_block_ids = []
elif isinstance(new_block_ids, tuple):
new_block_ids = new_block_ids[0]
elif isinstance(new_block_ids, list):
pass
else:
raise ValueError(
f"Unsupported new_block_ids type {type(new_block_ids)}")
self.allocated_block_ids.extend(new_block_ids)
@dataclass
class ReqMeta:
# Request id
req_id: str
# Request tokens
token_ids: torch.Tensor
block_ids: list[int]
# # Slot mapping if exchange for block_id
# slot_mapping: torch.Tensor
# Skip save or not
save_spec: Optional[SaveSpec] = None
# load_spec
load_spec: Optional[LoadSpec] = None
is_last_chunk: Optional[bool] = None
@staticmethod
def from_request_tracker(
tracker: RequestTracker,
block_size: int,
load_spec: Optional[LoadSpec] = None,
skip_save: Optional[bool] = False,
is_last_chunk: Optional[bool] = None,
discard_partial_chunks: bool = True,
) -> Optional["ReqMeta"]:
"""Create the request metadata from a request tracker.
Args:
tracker (RequestTracker): the request tracker.
block_size (int): the block size in vLLM.
load_spec (Optional[LoadSpec]): the load spec for KV cache loading.
skip_save (bool): whether to skip the save operation.
discard_partial_chunks (bool): whether to discard partial chunks.
Returns:
the request metadata if we need to perform load/save
operations, None otherwise.
"""
input_token_ids = tracker.token_ids
input_token_len = len(input_token_ids)
# For save operation: do not save if the following condition is met
# 1. has already been saved before (num_saved_tokens > 0)
# 2. number of unsaved tokens is not reached the chunk boundary
skip_leading_tokens = tracker.num_saved_tokens
chunk_boundary = (cdiv(tracker.num_saved_tokens + 1, block_size) *
block_size if discard_partial_chunks else 0)
# Calculate number of tokens to save based on discard_partial_chunks
# setting
num_tokens_to_save = ((input_token_len // block_size * block_size)
if discard_partial_chunks else input_token_len)
skip_save = skip_save or num_tokens_to_save < chunk_boundary
if skip_save and load_spec is None:
return None
# If we need to save, update the number of saved tokens
if not skip_save:
tracker.num_saved_tokens = num_tokens_to_save
save_spec = SaveSpec(skip_leading_tokens, not skip_save)
# Calculate the token ids and slot mappings for load and save
# OPTIMIZATION: pre-allocate the buffer for token ids and block ids
token_ids = torch.tensor(input_token_ids)[:num_tokens_to_save]
# # For load operation: check whether the request is scheduled to load
if load_spec is not None and load_spec.can_load:
logger.debug(
"Scheduled to load %d tokens for request %s",
load_spec.mooncake_cached_tokens,
tracker.req_id,
)
else:
# Do not load if not in `can_load` state
load_spec = None
logger.debug(
f"request:{tracker.req_id}, meta save spec:{save_spec}, meta load spec:{load_spec}"
)
return ReqMeta(
req_id=tracker.req_id,
token_ids=token_ids,
block_ids=tracker.allocated_block_ids,
save_spec=save_spec,
load_spec=load_spec,
is_last_chunk=is_last_chunk,
)
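A worked example (toy numbers) of the save-skipping rule in `from_request_tracker` above: with `discard_partial_chunks=True`, a request is re-saved only once a full new chunk beyond `num_saved_tokens` is available.

```python
from math import ceil

block_size = 128
num_saved_tokens = 128            # one chunk already saved
chunk_boundary = ceil((num_saved_tokens + 1) / block_size) * block_size  # 256

# 200 scheduled tokens round down to one full chunk -> nothing new to save.
assert (200 // block_size) * block_size < chunk_boundary

# 300 scheduled tokens round down to two full chunks -> save resumes.
assert (300 // block_size) * block_size >= chunk_boundary
```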
class MooncakeConnectorMetadata(KVConnectorMetadata):
def __init__(self, unfinished_request_ids):
self.requests = []
self.unfinished_request_ids = unfinished_request_ids
def add_request(self, req_meta: ReqMeta) -> None:
"""Add a request to the metadata.
Args:
req_meta (ReqMeta): the request metadata.
"""
self.requests.append(req_meta)
@dataclass
class LasyerMultiBlockReqMeta:
req_id: str
keys: List[LayerMooncakeEngineKey]
starts: List[int]
ends: list[int]
block_ids: list[int]
layer_id: int
@dataclass
class MooncakeStoreConfig:
local_hostname: str
metadata_server: str
global_segment_size: Union[int, str]
local_buffer_size: int
protocol: str
device_name: str
master_server_address: str
use_ascend_direct: bool
@staticmethod
def from_file(file_path: str) -> "MooncakeStoreConfig":
with open(file_path) as file:
config = json.load(file)
return MooncakeStoreConfig(
local_hostname=config.get("local_hostname"),
metadata_server=config.get("metadata_server"),
global_segment_size=_parse_global_segment_size(
config.get("global_segment_size",
DEFAULT_GLOBAL_SEGMENT_SIZE)),
local_buffer_size=(config.get("local_buffer_size",
DEFAULT_LOCAL_BUFFER_SIZE)),
protocol=config.get("protocol", "tcp"),
device_name=config.get("device_name", ""),
master_server_address=config.get("master_server_address"),
use_ascend_direct=config.get("use_ascend_direct", False))
@staticmethod
def load_from_env() -> "MooncakeStoreConfig":
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
if not config_path:
raise ValueError(
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
return MooncakeStoreConfig.from_file(config_path)
def _parse_global_segment_size(value) -> int:
"""
Parse storage size strings with support for units: GB, MB, KB, B
Args:
value: Input value (int, str, or other convertible types)
Returns:
int: Size in bytes
Raises:
ValueError: For invalid format, missing number, or negative values
TypeError: For unsupported input types
"""
if isinstance(value, int):
return value
elif not isinstance(value, str):
try:
return int(value)
except (TypeError, ValueError) as e:
raise TypeError(
f"Unsupported type for global_segment_size: {type(value)}"
) from e
cleaned_input = value.strip().lower()
if not cleaned_input:
raise ValueError("global segment size cannot be empty.")
UNIT_MULTIPLIERS = {
'gb': 1024**3, # 1 GB = 1024^3 bytes
'mb': 1024**2, # 1 MB = 1024^2 bytes
'kb': 1024, # 1 KB = 1024 bytes
'b': 1 # 1 B = 1 byte
}
pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
match = re.match(pattern, cleaned_input)
if not match:
raise ValueError(f"Invalid format: '{value}'")
number_str = match.group(1)
unit = match.group(2) or 'b'
multiplier = UNIT_MULTIPLIERS[unit]
return _convert_to_bytes(number_str, multiplier, value)
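A usage sketch, run against `_parse_global_segment_size` as defined above: plain integers pass through, and strings accept case-insensitive unit suffixes with optional whitespace.

```python
assert _parse_global_segment_size(4096) == 4096
assert _parse_global_segment_size("2gb") == 2 * 1024**3
assert _parse_global_segment_size(" 512 MB ") == 512 * 1024**2
assert _parse_global_segment_size("1.5kb") == 1536
```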
def _convert_to_bytes(number_str: str, multiplier: int,
original_input: str) -> int:
"""
Convert numeric string to byte count
Args:
number_str: Numeric portion of input
multiplier: Unit conversion factor
original_input: Original input string (for error messages)
Returns:
int: Byte count
Raises:
ValueError: For invalid numbers or negative results
"""
try:
numeric_value = float(number_str)
except ValueError:
raise ValueError(
f"Invalid numeric value '{number_str}' in: '{original_input}'")
# Calculate byte count
try:
byte_count = int(numeric_value * multiplier)
except OverflowError:
raise ValueError(f"Storage size too large: '{original_input}'")
return byte_count

View File

@@ -1,282 +0,0 @@
import queue
import threading
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Optional
import torch
from vllm.utils import logger
from vllm_ascend.distributed.mooncake.config_data import (
ChunkedTokenDatabase, LasyerMultiBlockReqMeta)
from vllm_ascend.distributed.mooncake.mooncake_store import Mooncakestore
class KVTransferThread(threading.Thread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event, name: str):
super().__init__(daemon=True, name=name)
self.tp_rank = tp_rank
self.tp_size = tp_size
self.m_store = m_store
self.ready_event = ready_event
self.kv_caches_base_addr = local_kv_caches_base_addr
self.block_len = block_len
self.token_database = token_database
self.block_size = block_size
self.done_task_lock = threading.Lock()
# TODO(jianzs): find a better way to detect MLA.
self.use_mla = len(block_len) == 2
self.request_queue: queue.Queue[Any] = queue.Queue()
# TODO(jianzs): make this configurable
self.executor = ThreadPoolExecutor(max_workers=32)
self.finished_requests: set[str] = set()
def prepare_value(self, start: int, end: int, block_ids: list[int]):
addr_list = []
size_list = []
block_id = block_ids[start // self.block_size]
for index, base_addr in enumerate(self.kv_caches_base_addr):
block_len = (self.block_len[index % 2]
if self.use_mla else self.block_len[0])
addr = base_addr + block_id * block_len
length = int(block_len / self.block_size * (end - start))
addr_list.append(addr)
size_list.append(length)
return addr_list, size_list, block_id
def prepare_value_layer(self, start: int, end: int, block_ids: list[int],
layer_id: int):
block_id = block_ids[start // self.block_size]
if self.use_mla:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[1]
length_k = int(self.block_len[0] / self.block_size * (end - start))
length_v = int(self.block_len[1] / self.block_size * (end - start))
size_list = [length_k, length_v]
else:
addr_k = self.kv_caches_base_addr[layer_id *
2] + block_id * self.block_len[0]
addr_v = self.kv_caches_base_addr[layer_id * 2 +
1] + block_id * self.block_len[0]
length = int(self.block_len[0] / self.block_size * (end - start))
size_list = [length, length]
addr_list = [addr_k, addr_v]
return addr_list, size_list
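A worked example (toy numbers) of the address arithmetic in the two helpers above: a block's device address is `base + block_id * block_len`, and a partial chunk of `end - start` tokens maps to a proportional byte length.

```python
block_size = 128                   # tokens per block
block_len = 128 * 16 * 128 * 2     # bytes per block (illustrative shape)
base_addr, block_id = 0x10_0000, 3

addr = base_addr + block_id * block_len
length = int(block_len / block_size * (96 - 0))   # only 96 of 128 tokens

assert addr == 0x10_0000 + 3 * block_len
assert length == block_len * 96 // 128
```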
def add_request(
self,
req_id: str,
tokens: torch.Tensor,
block_ids: list[int],
mask: Optional[torch.Tensor] = None,
is_last_chunk: Optional[bool] = None,
) -> torch.Tensor:
req = ({
"req_id": req_id,
"tokens": tokens,
"block_ids": block_ids,
"mask": mask,
"is_last_chunk": is_last_chunk,
})
self.request_queue.put(req)
def get_and_clear_finished_requests(self) -> set[str]:
"""
Get and clear the requests that have been completed.
Returns:
A set of request IDs that have been completed.
"""
with self.done_task_lock:
finished_requests = self.finished_requests.copy()
self.finished_requests.clear()
return finished_requests
def set_finished_request(self, req_id):
with self.done_task_lock:
self.finished_requests.add(req_id)
def run(self):
"""Run the thread to handle KV cache transfer requests."""
self.ready_event.set()
while True:
try:
request_data = self.request_queue.get()
if request_data is None:
logger.warning("Received a None request!")
self.request_queue.task_done()
continue
self._handle_request(request_data)
except Exception as e:
logger.error(f"Error in KVCacheTransferThread: {e}")
def _handle_request(self, req_meta: dict[str, Any]):
pass
class KVCacheStoreSendingThread(KVTransferThread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event):
super().__init__(tp_rank,
tp_size,
m_store,
local_kv_caches_base_addr,
token_database,
block_len,
block_size,
ready_event,
name="KVCacheSendingThread")
def _handle_request(self, req_meta: dict[str, Any]):
tokens = req_meta["tokens"]
mask = req_meta["mask"]
block_ids = req_meta["block_ids"]
req_id = req_meta["req_id"]
is_last_chunk = req_meta["is_last_chunk"]
if self.m_store.config.use_ascend_direct:
addr_list = []
size_list = []
key_list = []
blockIds = []
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, block_id = self.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
blockIds.append(block_id)
torch.npu.current_stream().synchronize()
self.m_store.put_batch(key_list, addr_list, size_list, blockIds)
else:
torch.npu.current_stream().synchronize()
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, _ = self.prepare_value(start, end, block_ids)
self.m_store.put(key, addr, size)
if is_last_chunk:
self.set_finished_request(req_id)
self.request_queue.task_done()
class KVCacheStoreRecvingThread(KVTransferThread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event):
super().__init__(tp_rank,
tp_size,
m_store,
local_kv_caches_base_addr,
token_database,
block_len,
block_size,
ready_event,
name="KVCacheStoreRecvingThread")
def _handle_request(self, req_meta: dict[str, Any]):
tokens = req_meta["tokens"]
mask = req_meta["mask"]
block_ids = req_meta["block_ids"]
req_id = req_meta["req_id"]
if self.m_store.config.use_ascend_direct:
addr_list = []
size_list = []
key_list = []
blockIds = []
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, block_id = self.prepare_value(
start, end, block_ids)
key_list.append(key.to_string())
addr_list.append(addr)
size_list.append(size)
blockIds.append(block_id)
self.m_store.get_batch(key_list, addr_list, size_list, blockIds)
else:
for start, end, key in self.token_database.process_tokens(
tokens, mask):
addr, size, _ = self.prepare_value(start, end, block_ids)
self.m_store.get(key, addr, size)
self.set_finished_request(req_id)
self.request_queue.task_done()
class KVCacheStoreLayerSendingThread(KVTransferThread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event,
num_layers: int):
super().__init__(tp_rank,
tp_size,
m_store,
local_kv_caches_base_addr,
token_database,
block_len,
block_size,
ready_event,
name="KVCacheStoreLayerSendingThread")
self.final_layer_id = num_layers - 1
def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
torch.npu.current_stream().synchronize()
for index, key in enumerate(req_meta.keys):
addr, size = self.prepare_value_layer(req_meta.starts[index],
req_meta.ends[index],
req_meta.block_ids,
req_meta.layer_id)
self.m_store.put(key, addr, size)
if req_meta.layer_id == self.final_layer_id:
self.set_finished_request(req_meta.req_id)
self.request_queue.task_done()
class KVCacheStoreLayerRecvingThread(KVTransferThread):
def __init__(self, tp_rank: int, tp_size: int, m_store: Mooncakestore,
local_kv_caches_base_addr: list[int],
token_database: ChunkedTokenDatabase, block_len: list[int],
block_size: int, ready_event: threading.Event,
get_event: threading.Event):
super().__init__(tp_rank,
tp_size,
m_store,
local_kv_caches_base_addr,
token_database,
block_len,
block_size,
ready_event,
name="KVCacheStoreLayerRecvingThread")
self.get_event = get_event
def add_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta) -> torch.Tensor:
self.request_queue.put(req_meta)
def _handle_request( # type: ignore[override]
self, req_meta: LasyerMultiBlockReqMeta):
for index, key in enumerate(req_meta.keys):
addr, size = self.prepare_value_layer(req_meta.starts[index],
req_meta.ends[index],
req_meta.block_ids,
req_meta.layer_id)
self.m_store.get(key, addr, size)
self.request_queue.task_done()
self.get_event.set()

View File

@@ -1,127 +0,0 @@
# Standard
import os
# Third Party
from mooncake.store import ReplicateConfig # type: ignore
from vllm.config import ParallelConfig
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
from vllm.utils import logger
from vllm.utils.network_utils import get_ip
from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
from .config_data import MooncakeStoreConfig
METADATA_BYTES_LEN = 24
BASE_PORT = int(os.getenv("VLLM_BASE_PORT", "8790"))
class Mooncakestore():
def __init__(self, parallel_config: ParallelConfig):
try:
from mooncake.store import MooncakeDistributedStore # type: ignore
except ImportError as e:
raise ImportError(
"Please install mooncake by following the instructions at "
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
"to run vLLM with MooncakeConnector.") from e
tp_rank = get_tensor_model_parallel_rank()
tp_size = parallel_config.tensor_parallel_size
dp_rank = parallel_config.data_parallel_rank_local
all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None)
if not all_device_ids:
device_ids_list = list(
range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
else:
device_ids_list = list(map(int, all_device_ids.split(',')))
assert len(device_ids_list) > tp_rank
device_id = device_ids_list[tp_rank]
self.config = MooncakeStoreConfig.load_from_env()
self.store = MooncakeDistributedStore()
if self.config.protocol == "ascend" and not self.config.use_ascend_direct:
local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \
":npu_" + str(device_id)
ret = self.store.setup(local_hostname, self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address)
else:
local_hostname = get_ip()
transfer_engine = get_global_te(local_hostname, device_name=None)
self.local_seg = local_hostname + ":" + str(
transfer_engine.get_rpc_port())
ret = self.store.setup(self.local_seg, self.config.metadata_server,
self.config.global_segment_size,
self.config.local_buffer_size,
self.config.protocol,
self.config.device_name,
self.config.master_server_address,
transfer_engine.get_engine())
if ret != 0:
msg = "Initialize mooncake failed."
logger.error(msg)
raise RuntimeError(msg)
def exists(self, key: MooncakeEngineKey) -> bool:
return self.store.is_exist(key.to_string()) == 1
def batch_exists(self, keys: list[str]) -> list[int]:
return self.store.batch_is_exist(keys)
def register_buffer(self, ptr, length):
return self.store.register_buffer(ptr, length)
def get_batch(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]], block_ids: list[int]):
try:
res = self.store.batch_get_into_multi_buffers(
keys, addrs, sizes, True)
for value in res:
if value < 0:
logger.error(f"Failed to get key {keys},res:{res}")
except Exception as e:
logger.error(f"Failed to get key {keys}. {e}")
def put_batch(self, keys: list[str], addrs: list[list[int]],
sizes: list[list[int]], block_ids: list[int]):
try:
config = ReplicateConfig()
config.preferred_segment = self.local_seg
config.prefer_alloc_in_same_node = True
res = self.store.batch_put_from_multi_buffers(
keys, addrs, sizes, config)
for value in res:
if value < 0:
logger.error(f"Failed to put key {keys},res:{res}")
except Exception as e:
logger.error(f"Failed to put key {keys},error:{e}")
def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
expect_res = sum(size)
key_str = key.to_string()
try:
res = self.store.batch_get_into_ascend(key_str, addr, size)
if res[0] != expect_res:
logger.error(f"Failed to get key: [{key_str}] .")
except Exception:
logger.error(f"Failed to get key: [{key_str}] .")
return res
def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
key_str = key.to_string()
try:
ret = self.store.batch_put_from_ascend(key_str, addr, size)
if ret[0] != 0:
logger.error(f"Failed to put key {key_str}.")
except Exception:
logger.error(f"Failed to put key {key_str}.")
return ret
def close(self):
self.store.close()
logger.info("Closed the mooncake store connection")

View File

@@ -1,38 +0,0 @@
import ipaddress
import threading
from typing import Optional
from mooncake.engine import TransferEngine # type: ignore
_global_te = None
_global_te_lock = threading.Lock()
def get_global_te(hostname: str, device_name: Optional[str]):
try:
ip = ipaddress.ip_address(hostname)
if isinstance(ip, ipaddress.IPv6Address):
raise RuntimeError(
"The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6."
)
except ValueError:
pass
global _global_te
if _global_te is None:
with _global_te_lock:
# Double-Checked Locking
if _global_te is None:
if TransferEngine is None:
raise RuntimeError("mooncake is not available")
transfer_engine = TransferEngine()
device_name = device_name if device_name is not None else ""
ret_value = transfer_engine.initialize(hostname,
"P2PHANDSHAKE",
"ascend", device_name)
if ret_value != 0:
raise RuntimeError(
f"TransferEngine initialization failed with ret_value: {ret_value}"
)
_global_te = transfer_engine
return _global_te

View File

@@ -31,11 +31,12 @@ from vllm.distributed.parallel_state import (
get_tensor_model_parallel_rank, get_tp_group)
from vllm.utils import logger
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.request import RequestStatus
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
from vllm_ascend.distributed.utils import get_transfer_timeout_value
from vllm_ascend.utils import prefill_context_parallel_enable
@@ -634,7 +635,10 @@ class MooncakeConnectorMetadata(KVConnectorMetadata):
class MooncakeConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
def __init__(self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None):
assert vllm_config.kv_transfer_config is not None
self.engine_id = vllm_config.kv_transfer_config.engine_id
@@ -944,7 +948,7 @@ class MooncakeConnectorWorker:
else:
hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
logger.info("Initializing Mooncake work %s", engine_id)
self.engine = get_global_te(hostname, device_name=None)
self.engine = global_te.get_transfer_engine(hostname, device_name=None)
self.te_rpc_port = self.engine.get_rpc_port()
# Background thread for sending or receiving KV caches.
@@ -1054,6 +1058,8 @@ class MooncakeConnectorWorker:
self.kv_caches = kv_caches
kv_caches_base_addr = []
ptrs = []
lengths = []
for cache_or_caches in kv_caches.values():
# Normalize to always be a list of caches
if self.use_mla:
@@ -1061,13 +1067,15 @@ class MooncakeConnectorWorker:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 2]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
ptrs.append(base_addr)
lengths.append(region_len)
elif self.use_sparse:
for i, cache in enumerate(cache_or_caches, 0):
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[i % 3]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
ptrs.append(base_addr)
lengths.append(region_len)
else:
cache_list = [
cache_or_caches
@@ -1076,8 +1084,9 @@ class MooncakeConnectorWorker:
base_addr = cache.data_ptr()
region_len = self.num_blocks * self.block_len[0]
kv_caches_base_addr.append(base_addr)
self._register(base_addr, region_len)
ptrs.append(base_addr)
lengths.append(region_len)
global_te.register_buffer(ptrs, lengths)
# After KV Caches registered, start the sending or receiving thread.
metadata = MooncakeAgentMetadata(
engine_id=self.engine_id,
@@ -1101,14 +1110,6 @@ class MooncakeConnectorWorker:
self.kv_recv_thread.start()
ready_event.wait()
def _register(self, ptr, length):
logger.debug(
"Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
"block_lens=%s", ptr, length, self.num_blocks, self.block_len)
ret_value = self.engine.register_memory(ptr, length)
if ret_value != 0:
raise RuntimeError("Mooncake memory registration failed.")
def get_finished(self) -> tuple[set[str], set[str]]:
done_sending = (
self.kv_send_thread.

View File

@@ -30,6 +30,7 @@ from vllm.distributed.parallel_state import (get_tensor_model_parallel_rank,
from vllm.utils import logger
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig
import vllm_ascend.envs as envs_ascend
from vllm_ascend.ascend_config import get_ascend_config
@@ -359,7 +360,10 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
class MooncakeLayerwiseConnector(KVConnectorBase_V1):
def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
def __init__(self,
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None):
assert vllm_config.kv_transfer_config is not None
self.engine_id = vllm_config.kv_transfer_config.engine_id
self._connector_metadata = MooncakeLayerwiseConnectorMetadata()

View File

@@ -0,0 +1,53 @@
import ipaddress
import threading
from typing import Optional
from mooncake.engine import TransferEngine # type: ignore
class GlobalTE():
def __init__(self):
self.transfer_engine = None
self.is_register_buffer: bool = False
self.transfer_engine_lock = threading.Lock()
self.register_buffer_lock = threading.Lock()
def get_transfer_engine(self, hostname: str, device_name: Optional[str]):
try:
ip = ipaddress.ip_address(hostname)
if isinstance(ip, ipaddress.IPv6Address):
raise RuntimeError(
"The backend of mooncake's Ascend Direct Xfer Library currently does not support IPv6."
)
except ValueError:
pass
if self.transfer_engine is None:
with self.transfer_engine_lock:
# Double-Checked Locking
if self.transfer_engine is None:
if TransferEngine is None:
raise RuntimeError("mooncake is not available")
self.transfer_engine = TransferEngine()
device_name = device_name if device_name is not None else ""
ret_value = self.transfer_engine.initialize(
hostname, "P2PHANDSHAKE", "ascend", device_name)
if ret_value != 0:
raise RuntimeError(
f"TransferEngine initialization failed with ret_value: {ret_value}"
)
return self.transfer_engine
def register_buffer(self, ptrs: list[int], sizes: list[int]):
with self.register_buffer_lock:
assert self.transfer_engine is not None, "Transfer engine must be initialized"
if self.is_register_buffer:
return
for ptr, size in zip(ptrs, sizes):
ret_value = self.transfer_engine.register_memory(ptr, size)
if ret_value != 0:
raise RuntimeError("Mooncake memory registration failed.")
self.is_register_buffer = True
global_te = GlobalTE()
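A usage sketch of the module-level singleton above (the hostname and buffer values are illustrative, and mooncake must be installed): every connector in the process shares one `TransferEngine`, and `register_buffer` becomes a no-op after the first successful call.

```python
engine = global_te.get_transfer_engine("10.0.0.1", device_name=None)
rpc_port = engine.get_rpc_port()

# First call registers; the is_register_buffer flag makes later calls no-ops.
global_te.register_buffer(ptrs=[0x10000000], sizes=[4096])
global_te.register_buffer(ptrs=[0x20000000], sizes=[4096])   # skipped
```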