[Feat](Mooncake) Supports multiple input suffixes for global_segment_size (#3690)
### What this PR does / why we need it?
- global_segment_size and local_buffer_size use constants for unified
management.
- Newly added support for input formats ending with GB, MB, KB, and B,
while being compatible with existing input methods.
### Does this PR introduce _any_ user-facing change?
- Users can use new input methods
- The documentation has also been modified
### How was this patch tested?
- vLLM version: v0.11.0
- vLLM main:
83f478bb19
---------
Signed-off-by: 李子琦 <liziqi_ing@163.com>
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
* Software:
|
||||
* Python >= 3.9, < 3.12
|
||||
* CANN >= 8.3.rc1
|
||||
* PyTorch == 2.7.1, torch-npu == 2.7.1
|
||||
* PyTorch >= 2.7.1, torch-npu >= 2.7.1.dev20250724
|
||||
* vLLM:main branch
|
||||
* vLLM-Ascend:main branch
|
||||
* Mooncake:main branch
|
||||
@@ -41,7 +41,7 @@ The environment variable **MOONCAKE_CONFIG_PATH** is configured to the full path
|
||||
"use_ascend_direct": true,
|
||||
"alloc_in_same_node": true,
|
||||
"master_server_address": "xx.xx.xx.xx:50088",
|
||||
"global_segment_size": 30000000000
|
||||
"global_segment_size": "1GB" (1024MB/1048576KB/1073741824B/1073741824)
|
||||
}
|
||||
```
|
||||
|
||||
|
||||
68
tests/ut/distributed/mooncake/test_config_data.py
Normal file
68
tests/ut/distributed/mooncake/test_config_data.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import unittest
|
||||
|
||||
from vllm_ascend.distributed.mooncake.config_data import (
|
||||
_convert_to_bytes, _parse_global_segment_size)
|
||||
|
||||
|
||||
class TestParseGlobalSegmentSize(unittest.TestCase):
|
||||
|
||||
def test_int_input(self):
|
||||
self.assertEqual(_parse_global_segment_size(1024), 1024)
|
||||
self.assertEqual(_parse_global_segment_size(0), 0)
|
||||
|
||||
def test_gb_unit(self):
|
||||
self.assertEqual(_parse_global_segment_size("2GB"), 2 * 1024**3)
|
||||
self.assertEqual(_parse_global_segment_size("1.5GB"),
|
||||
int(1.5 * 1024**3))
|
||||
self.assertEqual(_parse_global_segment_size(" 2 GB "), 2 * 1024**3)
|
||||
|
||||
def test_gb_unit_edge_cases(self):
|
||||
with self.assertRaises(ValueError):
|
||||
_parse_global_segment_size("GB")
|
||||
with self.assertRaises(ValueError):
|
||||
_parse_global_segment_size("abcGB")
|
||||
|
||||
def test_mb_unit(self):
|
||||
self.assertEqual(_parse_global_segment_size("512MB"), 512 * 1024**2)
|
||||
self.assertEqual(_parse_global_segment_size("0.5MB"),
|
||||
int(0.5 * 1024**2))
|
||||
self.assertEqual(_parse_global_segment_size("1024MB"), 1024 * 1024**2)
|
||||
|
||||
def test_kb_unit(self):
|
||||
self.assertEqual(_parse_global_segment_size("256KB"), 256 * 1024)
|
||||
self.assertEqual(_parse_global_segment_size("1.25KB"),
|
||||
int(1.25 * 1024))
|
||||
|
||||
def test_b_unit(self):
|
||||
self.assertEqual(_parse_global_segment_size("4096B"), 4096)
|
||||
self.assertEqual(_parse_global_segment_size("1024b"), 1024)
|
||||
|
||||
def test_no_unit(self):
|
||||
self.assertEqual(_parse_global_segment_size("2048"), 2048)
|
||||
self.assertEqual(_parse_global_segment_size("0"), 0)
|
||||
|
||||
def test_non_string_non_int_input(self):
|
||||
self.assertEqual(_parse_global_segment_size(2048.0), 2048)
|
||||
self.assertEqual(_parse_global_segment_size(True), 1)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
_parse_global_segment_size(None)
|
||||
|
||||
with self.assertRaises(TypeError):
|
||||
_parse_global_segment_size({"size": 1024})
|
||||
|
||||
|
||||
class TestConvertToBytes(unittest.TestCase):
|
||||
|
||||
def test_valid_conversion(self):
|
||||
self.assertEqual(_convert_to_bytes("10", 1, "10"), 10)
|
||||
self.assertEqual(_convert_to_bytes("1.5", 1024, "1.5KB"),
|
||||
int(1.5 * 1024))
|
||||
self.assertEqual(_convert_to_bytes("0", 1024**3, "0GB"), 0)
|
||||
|
||||
def test_invalid_numbers(self):
|
||||
with self.assertRaises(ValueError):
|
||||
_convert_to_bytes("abc", 1, "abc")
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
_convert_to_bytes("1.2.3", 1024, "1.2.3KB")
|
||||
@@ -2,6 +2,7 @@ import array
|
||||
import hashlib
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
@@ -11,6 +12,9 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import \
|
||||
from vllm.utils import cdiv, logger
|
||||
from vllm.v1.core.sched.output import NewRequestData
|
||||
|
||||
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
|
||||
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
|
||||
|
||||
|
||||
@dataclass
|
||||
class MooncakeEngineMetadata:
|
||||
@@ -419,7 +423,7 @@ class LasyerMultiBlockReqMeta:
|
||||
class MooncakeStoreConfig:
|
||||
local_hostname: str
|
||||
metadata_server: str
|
||||
global_segment_size: int
|
||||
global_segment_size: Union[int, str]
|
||||
local_buffer_size: int
|
||||
protocol: str
|
||||
device_name: str
|
||||
@@ -433,8 +437,11 @@ class MooncakeStoreConfig:
|
||||
return MooncakeStoreConfig(
|
||||
local_hostname=config.get("local_hostname"),
|
||||
metadata_server=config.get("metadata_server"),
|
||||
global_segment_size=config.get("global_segment_size", 3355443200),
|
||||
local_buffer_size=config.get("local_buffer_size", 1073741824),
|
||||
global_segment_size=_parse_global_segment_size(
|
||||
config.get("global_segment_size",
|
||||
DEFAULT_GLOBAL_SEGMENT_SIZE)),
|
||||
local_buffer_size=(config.get("local_buffer_size",
|
||||
DEFAULT_LOCAL_BUFFER_SIZE)),
|
||||
protocol=config.get("protocol", "tcp"),
|
||||
device_name=config.get("device_name", ""),
|
||||
master_server_address=config.get("master_server_address"),
|
||||
@@ -446,4 +453,81 @@ class MooncakeStoreConfig:
|
||||
if not config_path:
|
||||
raise ValueError(
|
||||
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set.")
|
||||
return MooncakeStoreConfig.from_file(config_path)
|
||||
return MooncakeStoreConfig.from_file(config_path)
|
||||
|
||||
|
||||
def _parse_global_segment_size(value) -> int:
|
||||
"""
|
||||
Parse storage size strings with support for units: GB, MB, KB, B
|
||||
|
||||
Args:
|
||||
value: Input value (int, str, or other convertible types)
|
||||
|
||||
Returns:
|
||||
int: Size in bytes
|
||||
|
||||
Raises:
|
||||
ValueError: For invalid format, missing number, or negative values
|
||||
TypeError: For unsupported input types
|
||||
"""
|
||||
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
elif not isinstance(value, str):
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError) as e:
|
||||
raise TypeError(
|
||||
f"Unsupported type for global_segment_size: {type(value)}"
|
||||
) from e
|
||||
|
||||
cleaned_input = value.strip().lower()
|
||||
if not cleaned_input:
|
||||
raise ValueError("global segment size cannot be empty.")
|
||||
|
||||
UNIT_MULTIPLIERS = {
|
||||
'gb': 1024**3, # 1 GB = 1024^3 bytes
|
||||
'mb': 1024**2, # 1 MB = 1024^2 bytes
|
||||
'kb': 1024, # 1 KB = 1024 bytes
|
||||
'b': 1 # 1 B = 1 byte
|
||||
}
|
||||
pattern = r'^\s*([\d.]+)\s*(gb|mb|kb|b)?\s*$'
|
||||
match = re.match(pattern, cleaned_input)
|
||||
|
||||
if not match:
|
||||
raise ValueError(f"Invalid format: '{value}'")
|
||||
|
||||
number_str = match.group(1)
|
||||
unit = match.group(2) or 'b'
|
||||
|
||||
multiplier = UNIT_MULTIPLIERS[unit]
|
||||
return _convert_to_bytes(number_str, multiplier, value)
|
||||
|
||||
|
||||
def _convert_to_bytes(number_str: str, multiplier: int,
|
||||
original_input: str) -> int:
|
||||
"""
|
||||
Convert numeric string to byte count
|
||||
|
||||
Args:
|
||||
number_str: Numeric portion of input
|
||||
multiplier: Unit conversion factor
|
||||
original_input: Original input string (for error messages)
|
||||
|
||||
Returns:
|
||||
int: Byte count
|
||||
|
||||
Raises:
|
||||
ValueError: For invalid numbers or negative results
|
||||
"""
|
||||
try:
|
||||
numeric_value = float(number_str)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
f"Invalid numeric value '{number_str}' in: '{original_input}'")
|
||||
# Calculate byte count
|
||||
try:
|
||||
byte_count = int(numeric_value * multiplier)
|
||||
except OverflowError:
|
||||
raise ValueError(f"Storage size too large: '{original_input}'")
|
||||
return byte_count
|
||||
|
||||
Reference in New Issue
Block a user