What this PR does / why we need it?
Improve usability: `local_buffer_size` now accepts a size with a unit suffix (GB, MB, KB, or B),
for example "2GB":
{
"local_hostname": "XXX.XXX.XXX.XXX",
"metadata_server": "P2PHANDSHAKE",
"protocol": "ascend",
"device_name": "",
"use_ascend_direct": true,
"master_server_address": "XXX.XXX.XXX.XXX:50088",
"global_segment_size": 60000000000,
"local_buffer_size": "2GB"
}
Does this PR introduce any user-facing change?
Yes. `local_buffer_size` in the Mooncake config file now accepts values with the units GB, MB, KB, or B (e.g. "2GB").
How was this patch tested?
Ran Mooncake with config files that set `local_buffer_size` using each of the units GB, MB, KB, and B.
- vLLM version: v0.12.0
- vLLM main: ad32e3e19c
Signed-off-by: lty <linxianchong1@huawei.com>
187 lines
6.5 KiB
Python
187 lines
6.5 KiB
Python
# Standard
|
|
import json
|
|
import os
|
|
import re
|
|
from dataclasses import dataclass
|
|
from typing import Union
|
|
|
|
# Third Party
|
|
from vllm.config import ParallelConfig
|
|
from vllm.logger import logger
|
|
from vllm.utils.network_utils import get_ip
|
|
|
|
from vllm_ascend.distributed.kvpool.backend.backend import Backend
|
|
from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
|
|
|
# Fallback sizes, in bytes, used by MooncakeStoreConfig.from_file when the
# JSON config omits "global_segment_size" / "local_buffer_size".
DEFAULT_GLOBAL_SEGMENT_SIZE = 25 * 1024**3 // 8  # 3.125 GiB == 3355443200 bytes
DEFAULT_LOCAL_BUFFER_SIZE = 1024**3  # 1 GiB == 1073741824 bytes
|
|
|
|
|
|
class MooncakeBackend(Backend):
    """KV-pool ``Backend`` that keeps data in a Mooncake distributed store.

    Settings come from ``MooncakeStoreConfig.load_from_env()``; the store is
    set up with the transfer engine returned by
    ``global_te.get_transfer_engine``.
    """

    def __init__(self, parallel_config: ParallelConfig):
        """Create the Mooncake store and run its ``setup``.

        Args:
            parallel_config: vLLM parallel configuration (not read by this
                constructor).

        Raises:
            ImportError: If ``mooncake.store`` cannot be imported.
            ValueError: If ``MOONCAKE_CONFIG_PATH`` is not set, or the
                configured protocol is not ``"ascend"``.
            RuntimeError: If ``store.setup`` returns a non-zero status.
        """
        try:
            from mooncake.store import MooncakeDistributedStore  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector.") from e

        self.config = MooncakeStoreConfig.load_from_env()
        # store.setup() is only wired up for the "ascend" protocol (it needs
        # the shared transfer engine). Any other protocol would leave the
        # store without a setup() call, so reject it up front with a clear
        # message instead of failing later.
        if self.config.protocol != "ascend":
            raise ValueError(
                f"Unsupported Mooncake protocol {self.config.protocol!r}: "
                "MooncakeBackend currently only supports 'ascend'.")

        self.store = MooncakeDistributedStore()
        local_hostname = get_ip()
        transfer_engine = global_te.get_transfer_engine(local_hostname,
                                                        device_name=None)
        # This process's segment name: "<ip>:<transfer-engine rpc port>".
        self.local_seg = f"{local_hostname}:{transfer_engine.get_rpc_port()}"
        ret = self.store.setup(self.local_seg, self.config.metadata_server,
                               self.config.global_segment_size,
                               self.config.local_buffer_size,
                               self.config.protocol,
                               self.config.device_name,
                               self.config.master_server_address,
                               transfer_engine.get_engine())
        if ret != 0:
            msg = "Initialize mooncake failed."
            logger.error(msg)
            raise RuntimeError(msg)

    def register_buffer(self, ptrs: list[int], lengths: list[int]):
        """Forward buffer registration (``ptrs``/``lengths``) to
        ``global_te.register_buffer``."""
        global_te.register_buffer(ptrs, lengths)

    def exists(self, keys: list[str]) -> list[int]:
        """Return ``store.batch_is_exist(keys)`` unchanged.

        NOTE(review): presumably one existence flag per key, in ``keys``
        order — confirm against the Mooncake store API.
        """
        return self.store.batch_is_exist(keys)

    def put(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        """Write ``keys`` from their buffers via
        ``store.batch_put_from_multi_buffers``.

        Best effort: a negative status in the result, or an exception from
        the store, is logged and never raised to the caller.
        """
        try:
            res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes)
            # Log the batch once, not once per failing entry.
            if any(value < 0 for value in res):
                logger.error("Failed to put key %s, res: %s", keys, res)
        except Exception as e:
            logger.error("Failed to put key %s, error: %s", keys, e)

    def get(self, keys: list[str], addrs: list[list[int]],
            sizes: list[list[int]]):
        """Read ``keys`` into their buffers via
        ``store.batch_get_into_multi_buffers``.

        Best effort: a negative status in the result, or an exception from
        the store, is logged and never raised to the caller.
        """
        try:
            res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes)
            # Log the batch once, not once per failing entry.
            if any(value < 0 for value in res):
                logger.error("Failed to get key %s, res: %s", keys, res)
        except Exception as e:
            logger.error("Failed to get key %s, error: %s", keys, e)
|
|
|
|
|
@dataclass
class MooncakeStoreConfig:
    """Mooncake store settings loaded from a JSON config file.

    Both size fields hold byte counts: ``from_file`` passes the raw JSON
    values through ``_parse_global_segment_size``, which accepts integers as
    well as strings with a unit suffix such as ``"2GB"``.
    """

    local_hostname: str
    metadata_server: str
    # Parsed to a byte count by from_file, so this is always an int.
    global_segment_size: int
    local_buffer_size: int
    protocol: str
    device_name: str
    master_server_address: str

    @staticmethod
    def from_file(file_path: str) -> "MooncakeStoreConfig":
        """Build a config from the JSON object stored at ``file_path``.

        Missing keys fall back to ``DEFAULT_GLOBAL_SEGMENT_SIZE`` /
        ``DEFAULT_LOCAL_BUFFER_SIZE`` for the sizes, ``"tcp"`` for
        ``protocol`` and ``""`` for ``device_name``. ``local_hostname``,
        ``metadata_server`` and ``master_server_address`` have no default
        and become ``None`` when absent.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            ValueError, TypeError: If a size value cannot be parsed.
        """
        # JSON text is UTF-8 (RFC 8259); don't depend on the locale default.
        with open(file_path, encoding="utf-8") as file:
            config = json.load(file)
        return MooncakeStoreConfig(
            local_hostname=config.get("local_hostname"),
            metadata_server=config.get("metadata_server"),
            global_segment_size=_parse_global_segment_size(
                config.get("global_segment_size",
                           DEFAULT_GLOBAL_SEGMENT_SIZE)),
            local_buffer_size=_parse_global_segment_size(
                config.get("local_buffer_size", DEFAULT_LOCAL_BUFFER_SIZE)),
            protocol=config.get("protocol", "tcp"),
            device_name=config.get("device_name", ""),
            master_server_address=config.get("master_server_address"))

    @staticmethod
    def load_from_env() -> "MooncakeStoreConfig":
        """Load the config from the file named by ``MOONCAKE_CONFIG_PATH``.

        Raises:
            ValueError: If the variable is unset or empty.
        """
        config_path = os.getenv("MOONCAKE_CONFIG_PATH")
        if not config_path:
            raise ValueError(
                "The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
        return MooncakeStoreConfig.from_file(config_path)
|
|
|
|
|
|
def _parse_global_segment_size(value) -> int:
|
|
"""
|
|
Parse storage size strings with support for units: GB, MB, KB, B
|
|
|
|
Args:
|
|
value: Input value (int, str, or other convertible types)
|
|
|
|
Returns:
|
|
int: Size in bytes
|
|
|
|
Raises:
|
|
ValueError: For invalid format, missing number, or negative values
|
|
TypeError: For unsupported input types
|
|
"""
|
|
|
|
if isinstance(value, int):
|
|
return value
|
|
elif not isinstance(value, str):
|
|
try:
|
|
return int(value)
|
|
except (TypeError, ValueError) as e:
|
|
raise TypeError(
|
|
f"Unsupported type for global_segment_size: {type(value)}"
|
|
) from e
|
|
|
|
cleaned_input = value.strip().lower()
|
|
if not cleaned_input:
|
|
raise ValueError("global segment size cannot be empty.")
|
|
|
|
UNIT_MULTIPLIERS = {
|
|
'gb': 1024**3, # 1 GB = 1024^3 bytes
|
|
'mb': 1024**2, # 1 MB = 1024^2 bytes
|
|
'kb': 1024, # 1 KB = 1024 bytes
|
|
'b': 1 # 1 B = 1 byte
|
|
}
|
|
pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
|
|
match = re.match(pattern, cleaned_input)
|
|
|
|
if not match:
|
|
raise ValueError(f"Invalid format: '{value}'")
|
|
|
|
number_str = match.group(1)
|
|
unit = match.group(2) or 'b'
|
|
|
|
multiplier = UNIT_MULTIPLIERS[unit]
|
|
return _convert_to_bytes(number_str, multiplier, value)
|
|
|
|
|
|
def _convert_to_bytes(number_str: str, multiplier: int,
|
|
original_input: str) -> int:
|
|
"""
|
|
Convert numeric string to byte count
|
|
|
|
Args:
|
|
number_str: Numeric portion of input
|
|
multiplier: Unit conversion factor
|
|
original_input: Original input string (for error messages)
|
|
|
|
Returns:
|
|
int: Byte count
|
|
|
|
Raises:
|
|
ValueError: For invalid numbers or negative results
|
|
"""
|
|
try:
|
|
numeric_value = float(number_str)
|
|
except ValueError:
|
|
raise ValueError(
|
|
f"Invalid numeric value '{number_str}' in: '{original_input}'")
|
|
# Calculate byte count
|
|
try:
|
|
byte_count = int(numeric_value * multiplier)
|
|
except OverflowError:
|
|
raise ValueError(f"Storage size too large: '{original_input}'")
|
|
return byte_count
|