[Feature][main]reconstruction kvpool connector to ascend connector (#4438)
### What this PR does / why we need it? 1.In short, we renamed the existing MooncakeStoreConnector to AscendStoreConnector and extracted the storage engine interaction logic into a new Backend class. Associated RFC:https://github.com/vllm-project/vllm-ascend/issues/4329 2.Fixed the issue where the number of input parameters for the connector was incorrect, introduced in vllm 0.11.2 ### Does this PR introduce _any_ user-facing change? change MooncakeStoreConnector to AscendStoreConnector ### How was this patch tested? - vLLM version: v0.11.2 --------- Signed-off-by: fems14 <1804143737@qq.com>
This commit is contained in:
188
vllm_ascend/distributed/kvpool/backend/mooncake_backend.py
Normal file
188
vllm_ascend/distributed/kvpool/backend/mooncake_backend.py
Normal file
@@ -0,0 +1,188 @@
|
||||
# Standard
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
|
||||
# Third Party
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.utils import logger
|
||||
from vllm.utils.network_utils import get_ip
|
||||
|
||||
from vllm_ascend.distributed.kvpool.backend.backend import Backend
|
||||
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
||||
|
||||
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
|
||||
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
|
||||
|
||||
|
||||
class MooncakeBackend(Backend):
|
||||
|
||||
def __init__(self, parallel_config: ParallelConfig):
|
||||
try:
|
||||
from mooncake.store import MooncakeDistributedStore # type: ignore
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install mooncake by following the instructions at "
|
||||
"https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md " # noqa: E501
|
||||
"to run vLLM with MooncakeConnector.") from e
|
||||
self.config = MooncakeStoreConfig.load_from_env()
|
||||
self.store = MooncakeDistributedStore()
|
||||
if self.config.protocol == "ascend":
|
||||
local_hostname = get_ip()
|
||||
transfer_engine = global_te.get_transfer_engine(local_hostname,
|
||||
device_name=None)
|
||||
self.local_seg = local_hostname + ":" + str(
|
||||
transfer_engine.get_rpc_port())
|
||||
ret = self.store.setup(self.local_seg, self.config.metadata_server,
|
||||
self.config.global_segment_size,
|
||||
self.config.local_buffer_size,
|
||||
self.config.protocol,
|
||||
self.config.device_name,
|
||||
self.config.master_server_address,
|
||||
transfer_engine.get_engine())
|
||||
if ret != 0:
|
||||
msg = "Initialize mooncake failed."
|
||||
logger.error(msg)
|
||||
raise RuntimeError(msg)
|
||||
|
||||
def register_buffer(self, ptrs: list[int], lengths: list[int]):
|
||||
global_te.register_buffer(ptrs, lengths)
|
||||
|
||||
def exists(self, keys: list[str]) -> list[int]:
|
||||
return self.store.batch_is_exist(keys)
|
||||
|
||||
def put(self, keys: list[str], addrs: list[list[int]],
|
||||
sizes: list[list[int]]):
|
||||
try:
|
||||
res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes)
|
||||
for value in res:
|
||||
if value < 0:
|
||||
logger.error(f"Failed to put key {keys},res:{res}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to put key {keys},error:{e}")
|
||||
|
||||
def get(self, keys: list[str], addrs: list[list[int]],
|
||||
sizes: list[list[int]]):
|
||||
try:
|
||||
res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes)
|
||||
for value in res:
|
||||
if value < 0:
|
||||
logger.error(f"Failed to get key {keys}, res:{res}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get key {keys}, error:{e}")
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeStoreConfig:
|
||||
local_hostname: str
|
||||
metadata_server: str
|
||||
global_segment_size: Union[int, str]
|
||||
local_buffer_size: int
|
||||
protocol: str
|
||||
device_name: str
|
||||
master_server_address: str
|
||||
use_ascend_direct: bool
|
||||
|
||||
@staticmethod
|
||||
def from_file(file_path: str) -> "MooncakeStoreConfig":
|
||||
with open(file_path) as file:
|
||||
config = json.load(file)
|
||||
return MooncakeStoreConfig(
|
||||
local_hostname=config.get("local_hostname"),
|
||||
metadata_server=config.get("metadata_server"),
|
||||
global_segment_size=_parse_global_segment_size(
|
||||
config.get("global_segment_size",
|
||||
DEFAULT_GLOBAL_SEGMENT_SIZE)),
|
||||
local_buffer_size=(config.get("local_buffer_size",
|
||||
DEFAULT_LOCAL_BUFFER_SIZE)),
|
||||
protocol=config.get("protocol", "tcp"),
|
||||
device_name=config.get("device_name", ""),
|
||||
master_server_address=config.get("master_server_address"),
|
||||
use_ascend_direct=config.get("use_ascend_direct", False))
|
||||
|
||||
@staticmethod
|
||||
def load_from_env() -> "MooncakeStoreConfig":
|
||||
config_path = os.getenv("MOONCAKE_CONFIG_PATH")
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
||||
return MooncakeStoreConfig.from_file(config_path)
|
||||
|
||||
|
||||
def _parse_global_segment_size(value) -> int:
|
||||
"""
|
||||
Parse storage size strings with support for units: GB, MB, KB, B
|
||||
|
||||
Args:
|
||||
value: Input value (int, str, or other convertible types)
|
||||
|
||||
Returns:
|
||||
int: Size in bytes
|
||||
|
||||
Raises:
|
||||
ValueError: For invalid format, missing number, or negative values
|
||||
TypeError: For unsupported input types
|
||||
"""
|
||||
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise TypeError(
|
||||
f"Unsupported type for global_segment_size: {type(value)}"
|
||||
) from e
|
||||
|
||||
cleaned_input = value.strip().lower()
|
||||
if not cleaned_input:
|
||||
raise ValueError("global segment size cannot be empty.")
|
||||
|
||||
UNIT_MULTIPLIERS = {
|
||||
'gb': 1024**3, # 1 GB = 1024^3 bytes
|
||||
'mb': 1024**2, # 1 MB = 1024^2 bytes
|
||||
'kb': 1024, # 1 KB = 1024 bytes
|
||||
'b': 1 # 1 B = 1 byte
|
||||
}
|
||||
pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
|
||||
match = re.match(pattern, cleaned_input)
|
||||
|
||||
if not match:
|
||||
raise ValueError(f"Invalid format: '{value}'")
|
||||
|
||||
number_str = match.group(1)
|
||||
unit = match.group(2) or 'b'
|
||||
|
||||
multiplier = UNIT_MULTIPLIERS[unit]
|
||||
return _convert_to_bytes(number_str, multiplier, value)
|
||||
|
||||
|
||||
def _convert_to_bytes(number_str: str, multiplier: int,
|
||||
original_input: str) -> int:
|
||||
"""
|
||||
Convert numeric string to byte count
|
||||
|
||||
Args:
|
||||
number_str: Numeric portion of input
|
||||
multiplier: Unit conversion factor
|
||||
original_input: Original input string (for error messages)
|
||||
|
||||
Returns:
|
||||
int: Byte count
|
||||
|
||||
Raises:
|
||||
ValueError: For invalid numbers or negative results
|
||||
"""
|
||||
try:
|
||||
numeric_value = float(number_str)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid numeric value '{number_str}' in: '{original_input}'")
|
||||
# Calculate byte count
|
||||
try:
|
||||
byte_count = int(numeric_value * multiplier)
|
||||
except OverflowError:
|
||||
raise ValueError(f"Storage size too large: '{original_input}'")
|
||||
return byte_count
|
||||
Reference in New Issue
Block a user