[Feat](Mooncake) Supports multiple input suffixes for global_segment_size (#3690)
### What this PR does / why we need it?
- global_segment_size and local_buffer_size use constants for unified
management.
- Newly added support for input formats ending with GB, MB, KB, and B,
while being compatible with existing input methods.
### Does this PR introduce _any_ user-facing change?
- Users can use new input methods
- The documentation has also been modified
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: 李子琦 <liziqi_ing@163.com>
This commit is contained in:
@@ -5,7 +5,7 @@
|
|||||||
* Software:
|
* Software:
|
||||||
* Python >= 3.9, < 3.12
|
* Python >= 3.9, < 3.12
|
||||||
* CANN >= 8.3.rc1
|
* CANN >= 8.3.rc1
|
||||||
* PyTorch == 2.7.1, torch-npu == 2.7.1
|
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
|
||||||
* vLLM:main branch
|
* vLLM:main branch
|
||||||
* vLLM-Ascend:main branch
|
* vLLM-Ascend:main branch
|
||||||
* Mooncake:main branch
|
* Mooncake:main branch
|
||||||
@@ -41,7 +41,7 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path
|
|||||||
"use_ascend_direct": true,
|
"use_ascend_direct": true,
|
||||||
"alloc_in_same_node": true,
|
"alloc_in_same_node": true,
|
||||||
"master_server_address": "xx.xx.xx.xx:50088",
|
"master_server_address": "xx.xx.xx.xx:50088",
|
||||||
"global_segment_size": 30000000000
|
"global_segment_size": "1GB" (1024MB/1048576KB/1073741824B/1073741824)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
68
tests/ut/distributed/mooncake/test_config_data.py
Normal file
68
tests/ut/distributed/mooncake/test_config_data.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
from vllm_ascend.distributed.mooncake.config_data import (
|
||||||
|
_convert_to_bytes, _parse_global_segment_size)
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseGlobalSegmentSize(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_int_input(self):
|
||||||
|
self.assertEqual(_parse_global_segment_size(1024), 1024)
|
||||||
|
self.assertEqual(_parse_global_segment_size(0), 0)
|
||||||
|
|
||||||
|
def test_gb_unit(self):
|
||||||
|
self.assertEqual(_parse_global_segment_size("2GB"), 2 * 1024**3)
|
||||||
|
self.assertEqual(_parse_global_segment_size("1.5GB"),
|
||||||
|
int(1.5 * 1024**3))
|
||||||
|
self.assertEqual(_parse_global_segment_size(" 2 GB "), 2 * 1024**3)
|
||||||
|
|
||||||
|
def test_gb_unit_edge_cases(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_parse_global_segment_size("GB")
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_parse_global_segment_size("abcGB")
|
||||||
|
|
||||||
|
def test_mb_unit(self):
|
||||||
|
self.assertEqual(_parse_global_segment_size("512MB"), 512 * 1024**2)
|
||||||
|
self.assertEqual(_parse_global_segment_size("0.5MB"),
|
||||||
|
int(0.5 * 1024**2))
|
||||||
|
self.assertEqual(_parse_global_segment_size("1024MB"), 1024 * 1024**2)
|
||||||
|
|
||||||
|
def test_kb_unit(self):
|
||||||
|
self.assertEqual(_parse_global_segment_size("256KB"), 256 * 1024)
|
||||||
|
self.assertEqual(_parse_global_segment_size("1.25KB"),
|
||||||
|
int(1.25 * 1024))
|
||||||
|
|
||||||
|
def test_b_unit(self):
|
||||||
|
self.assertEqual(_parse_global_segment_size("4096B"), 4096)
|
||||||
|
self.assertEqual(_parse_global_segment_size("1024b"), 1024)
|
||||||
|
|
||||||
|
def test_no_unit(self):
|
||||||
|
self.assertEqual(_parse_global_segment_size("2048"), 2048)
|
||||||
|
self.assertEqual(_parse_global_segment_size("0"), 0)
|
||||||
|
|
||||||
|
def test_non_string_non_int_input(self):
|
||||||
|
self.assertEqual(_parse_global_segment_size(2048.0), 2048)
|
||||||
|
self.assertEqual(_parse_global_segment_size(True), 1)
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
_parse_global_segment_size(None)
|
||||||
|
|
||||||
|
with self.assertRaises(TypeError):
|
||||||
|
_parse_global_segment_size({"size": 1024})
|
||||||
|
|
||||||
|
|
||||||
|
class TestConvertToBytes(unittest.TestCase):
|
||||||
|
|
||||||
|
def test_valid_conversion(self):
|
||||||
|
self.assertEqual(_convert_to_bytes("10", 1, "10"), 10)
|
||||||
|
self.assertEqual(_convert_to_bytes("1.5", 1024, "1.5KB"),
|
||||||
|
int(1.5 * 1024))
|
||||||
|
self.assertEqual(_convert_to_bytes("0", 1024**3, "0GB"), 0)
|
||||||
|
|
||||||
|
def test_invalid_numbers(self):
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_convert_to_bytes("abc", 1, "abc")
|
||||||
|
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
_convert_to_bytes("1.2.3", 1024, "1.2.3KB")
|
||||||
@@ -2,6 +2,7 @@ import array
|
|||||||
import hashlib
|
import hashlib
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
|
import re
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Iterable, List, Optional, Tuple, Union
|
from typing import Iterable, List, Optional, Tuple, Union
|
||||||
|
|
||||||
@@ -11,6 +12,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
|||||||
from vllm.utils import cdiv, logger
|
from vllm.utils import cdiv, logger
|
||||||
from vllm.v1.core.sched.output import NewRequestData
|
from vllm.v1.core.sched.output import NewRequestData
|
||||||
|
|
||||||
|
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
|
||||||
|
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class MooncakeEngineMetadata:
|
class MooncakeEngineMetadata:
|
||||||
@@ -419,7 +423,7 @@ class LasyerMultiBlockReqMeta:
|
|||||||
class MooncakeStoreConfig:
|
class MooncakeStoreConfig:
|
||||||
local_hostname: str
|
local_hostname: str
|
||||||
metadata_server: str
|
metadata_server: str
|
||||||
global_segment_size: int
|
global_segment_size: Union[int, str]
|
||||||
local_buffer_size: int
|
local_buffer_size: int
|
||||||
protocol: str
|
protocol: str
|
||||||
device_name: str
|
device_name: str
|
||||||
@@ -433,8 +437,11 @@ class MooncakeStoreConfig:
|
|||||||
return MooncakeStoreConfig(
|
return MooncakeStoreConfig(
|
||||||
local_hostname=config.get("local_hostname"),
|
local_hostname=config.get("local_hostname"),
|
||||||
metadata_server=config.get("metadata_server"),
|
metadata_server=config.get("metadata_server"),
|
||||||
global_segment_size=config.get("global_segment_size", 3355443200),
|
global_segment_size=_parse_global_segment_size(
|
||||||
local_buffer_size=config.get("local_buffer_size", 1073741824),
|
config.get("global_segment_size",
|
||||||
|
DEFAULT_GLOBAL_SEGMENT_SIZE)),
|
||||||
|
local_buffer_size=(config.get("local_buffer_size",
|
||||||
|
DEFAULT_LOCAL_BUFFER_SIZE)),
|
||||||
protocol=config.get("protocol", "tcp"),
|
protocol=config.get("protocol", "tcp"),
|
||||||
device_name=config.get("device_name", ""),
|
device_name=config.get("device_name", ""),
|
||||||
master_server_address=config.get("master_server_address"),
|
master_server_address=config.get("master_server_address"),
|
||||||
@@ -446,4 +453,81 @@ class MooncakeStoreConfig:
|
|||||||
if not config_path:
|
if not config_path:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
||||||
return MooncakeStoreConfig.from_file(config_path)
|
return MooncakeStoreConfig.from_file(config_path)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_global_segment_size(value) -> int:
|
||||||
|
"""
|
||||||
|
Parse storage size strings with support for units: GB, MB, KB, B
|
||||||
|
|
||||||
|
Args:
|
||||||
|
value: Input value (int, str, or other convertible types)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Size in bytes
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: For invalid format, missing number, or negative values
|
||||||
|
TypeError: For unsupported input types
|
||||||
|
"""
|
||||||
|
|
||||||
|
if isinstance(value, int):
|
||||||
|
return value
|
||||||
|
elif not isinstance(value, str):
|
||||||
|
try:
|
||||||
|
return int(value)
|
||||||
|
except (TypeError, ValueError) as e:
|
||||||
|
raise TypeError(
|
||||||
|
f"Unsupported type for global_segment_size: {type(value)}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
cleaned_input = value.strip().lower()
|
||||||
|
if not cleaned_input:
|
||||||
|
raise ValueError("global segment size cannot be empty.")
|
||||||
|
|
||||||
|
UNIT_MULTIPLIERS = {
|
||||||
|
'gb': 1024**3, # 1 GB = 1024^3 bytes
|
||||||
|
'mb': 1024**2, # 1 MB = 1024^2 bytes
|
||||||
|
'kb': 1024, # 1 KB = 1024 bytes
|
||||||
|
'b': 1 # 1 B = 1 byte
|
||||||
|
}
|
||||||
|
pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
|
||||||
|
match = re.match(pattern, cleaned_input)
|
||||||
|
|
||||||
|
if not match:
|
||||||
|
raise ValueError(f"Invalid format: '{value}'")
|
||||||
|
|
||||||
|
number_str = match.group(1)
|
||||||
|
unit = match.group(2) or 'b'
|
||||||
|
|
||||||
|
multiplier = UNIT_MULTIPLIERS[unit]
|
||||||
|
return _convert_to_bytes(number_str, multiplier, value)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_to_bytes(number_str: str, multiplier: int,
|
||||||
|
original_input: str) -> int:
|
||||||
|
"""
|
||||||
|
Convert numeric string to byte count
|
||||||
|
|
||||||
|
Args:
|
||||||
|
number_str: Numeric portion of input
|
||||||
|
multiplier: Unit conversion factor
|
||||||
|
original_input: Original input string (for error messages)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Byte count
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: For invalid numbers or negative results
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
numeric_value = float(number_str)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid numeric value '{number_str}' in: '{original_input}'")
|
||||||
|
# Calculate byte count
|
||||||
|
try:
|
||||||
|
byte_count = int(numeric_value * multiplier)
|
||||||
|
except OverflowError:
|
||||||
|
raise ValueError(f"Storage size too large: '{original_input}'")
|
||||||
|
return byte_count
|
||||||
|
|||||||
Reference in New Issue
Block a user