### What this PR does / why we need it?

In multi-card Tensor Parallel (TP) scenarios, the KV pool queried only the first GPU card. When keys on the other cards had already been released, the query still reported success, which introduced accuracy issues. This PR modifies the KV pool's query logic to check all cards, resolving this problem.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: fems14 <1804143737@qq.com>
127 lines
5.3 KiB
Python
127 lines
5.3 KiB
Python
# Standard
|
|
import os
|
|
|
|
# Third Party
|
|
from mooncake.store import ReplicateConfig # type: ignore
|
|
from vllm.config import ParallelConfig
|
|
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
|
|
from vllm.utils import get_ip, logger
|
|
|
|
from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey
|
|
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
|
|
|
|
from .config_data import MooncakeStoreConfig
|
|
|
|
# Metadata length constant; presumably a size in bytes (it is not referenced
# by the code visible in this module view -- confirm against other users).
METADATA_BYTES_LEN = 24

# Base port number, overridable through the VLLM_BASE_PORT environment
# variable. Mooncakestore adds the device id to it when building the local
# hostname for the non-direct "ascend" protocol.
BASE_PORT = int(os.environ.get("VLLM_BASE_PORT", "8790"))
|
|
|
|
|
|
class Mooncakestore():
    """Wrapper around a ``MooncakeDistributedStore`` client.

    The constructor connects the store for the current tensor-parallel rank.
    The remaining methods forward existence checks and get/put requests to
    the underlying store. Failures of get/put requests are logged instead of
    being raised to the caller.
    """

    def __init__(self, parallel_config: ParallelConfig):
        """Create and set up the Mooncake store client.

        Args:
            parallel_config: Supplies ``tensor_parallel_size`` and
                ``data_parallel_rank_local``, used to derive the default
                device list when ``ASCEND_RT_VISIBLE_DEVICES`` is unset.

        Raises:
            ImportError: If the ``mooncake`` package is not installed.
            AssertionError: If the device list has no entry for this TP rank.
            RuntimeError: If ``store.setup`` returns a non-zero code.
        """
        try:
            from mooncake.store import MooncakeDistributedStore  # type: ignore
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run vLLM with MooncakeConnector.") from e

        tp_rank = get_tensor_model_parallel_rank()
        tp_size = parallel_config.tensor_parallel_size
        dp_rank = parallel_config.data_parallel_rank_local

        # Use the explicitly visible devices when given; otherwise assume the
        # contiguous range of tp_size devices owned by this local DP rank.
        all_device_ids = os.getenv("ASCEND_RT_VISIBLE_DEVICES", None)
        if not all_device_ids:
            device_ids_list = list(
                range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
        else:
            device_ids_list = list(map(int, all_device_ids.split(',')))
        assert len(device_ids_list) > tp_rank, (
            f"No device available for tp_rank {tp_rank}: "
            f"only {len(device_ids_list)} device id(s) found.")
        device_id = device_ids_list[tp_rank]

        self.config = MooncakeStoreConfig.load_from_env()
        self.store = MooncakeDistributedStore()
        if self.config.protocol == "ascend" and not self.config.use_ascend_direct:
            local_hostname = (
                f"{get_ip()}:{BASE_PORT + int(device_id)}:npu_{device_id}")
            # Fix: put_batch() reads self.local_seg, which used to be assigned
            # only in the other branch, so every put_batch() in this
            # configuration failed with AttributeError.
            # NOTE(review): assumes the name handed to setup() is also what
            # ReplicateConfig.preferred_segment expects, as in the other
            # branch -- confirm against the Mooncake API.
            self.local_seg = local_hostname
            ret = self.store.setup(local_hostname, self.config.metadata_server,
                                   self.config.global_segment_size,
                                   self.config.local_buffer_size,
                                   self.config.protocol,
                                   self.config.device_name,
                                   self.config.master_server_address)
        else:
            local_hostname = get_ip()
            transfer_engine = get_global_te(local_hostname, device_name=None)
            self.local_seg = f"{local_hostname}:{transfer_engine.get_rpc_port()}"
            # This branch additionally hands the shared transfer engine to
            # the store.
            ret = self.store.setup(self.local_seg, self.config.metadata_server,
                                   self.config.global_segment_size,
                                   self.config.local_buffer_size,
                                   self.config.protocol,
                                   self.config.device_name,
                                   self.config.master_server_address,
                                   transfer_engine.get_engine())
        if ret != 0:
            msg = "Initialize mooncake failed."
            logger.error(msg)
            raise RuntimeError(msg)

    def exists(self, key: MooncakeEngineKey) -> bool:
        """Return True when the store reports exactly 1 for ``key``."""
        return self.store.is_exist(key.to_string()) == 1

    def batch_exists(self, keys: list[str]) -> list[int]:
        """Return the store's per-key existence results for ``keys``."""
        return self.store.batch_is_exist(keys)

    def register_buffer(self, ptr, length):
        """Register the buffer ``ptr``/``length`` with the store.

        Returns whatever the underlying ``register_buffer`` call returns.
        """
        return self.store.register_buffer(ptr, length)

    def get_batch(self, keys: list[str], addrs: list[list[int]],
                  sizes: list[list[int]], block_ids: list[int]):
        """Fetch ``keys`` into the given buffers; log failures, never raise.

        Args:
            keys: Keys to read.
            addrs: Per-key lists of destination addresses.
            sizes: Per-key lists of sizes matching ``addrs``.
            block_ids: Accepted for interface compatibility; unused here.

        Returns:
            None. A negative per-key result or an exception is only logged.
        """
        try:
            # The trailing True is forwarded unchanged; its meaning is
            # defined by the Mooncake API (TODO confirm).
            res = self.store.batch_get_into_multi_buffers(
                keys, addrs, sizes, True)
            # Log once per call instead of once per failed element: the
            # message already contains the complete keys/result lists.
            if any(value < 0 for value in res):
                logger.error(f"Failed to get key {keys},res:{res}")
        except Exception as e:
            logger.error(f"Failed to get key {keys}. {e}")

    def put_batch(self, keys: list[str], addrs: list[list[int]],
                  sizes: list[list[int]], block_ids: list[int]):
        """Store ``keys`` from the given buffers; log failures, never raise.

        Args:
            keys: Keys to write.
            addrs: Per-key lists of source addresses.
            sizes: Per-key lists of sizes matching ``addrs``.
            block_ids: Accepted for interface compatibility; unused here.

        Returns:
            None. A negative per-key result or an exception is only logged.
        """
        try:
            config = ReplicateConfig()
            # Ask for placement in this instance's own segment / node.
            config.preferred_segment = self.local_seg
            config.prefer_alloc_in_same_node = True
            res = self.store.batch_put_from_multi_buffers(
                keys, addrs, sizes, config)
            # Log once per call instead of once per failed element.
            if any(value < 0 for value in res):
                logger.error(f"Failed to put key {keys},res:{res}")
        except Exception as e:
            logger.error(f"Failed to put key {keys},error:{e}")

    def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
        """Fetch a single key into the given buffers.

        Args:
            key: Key object; its ``to_string()`` value is sent to the store.
            addr: Destination addresses.
            size: Sizes matching ``addr``; their sum is the expected result.

        Returns:
            The store's result, or None when the store call raised.
            A result whose first element differs from ``sum(size)`` is
            logged as a failure but still returned.
        """
        expect_res = sum(size)
        key_str = key.to_string()
        # Fix: pre-bind the result so the final return cannot raise
        # UnboundLocalError after a swallowed exception.
        res = None
        try:
            res = self.store.batch_get_into_ascend(key_str, addr, size)
            if res[0] != expect_res:
                logger.error(f"Failed to get key: [{key_str}] .")
        except Exception as e:
            logger.error(f"Failed to get key: [{key_str}] . {e}")
        return res

    def put(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
        """Store a single key from the given buffers.

        Args:
            key: Key object; its ``to_string()`` value is sent to the store.
            addr: Source addresses.
            size: Sizes matching ``addr``.

        Returns:
            The store's result, or None when the store call raised.
            A result whose first element is non-zero is logged as a failure
            but still returned.
        """
        key_str = key.to_string()
        # Fix: pre-bind the result so the final return cannot raise
        # UnboundLocalError after a swallowed exception.
        ret = None
        try:
            ret = self.store.batch_put_from_ascend(key_str, addr, size)
            if ret[0] != 0:
                logger.error(f"Failed to put key {key_str}.")
        except Exception as e:
            logger.error(f"Failed to put key {key_str}. {e}")
        return ret

    def close(self):
        """Close the underlying store connection and log it."""
        self.store.close()
        logger.info("Closed the mooncake store connection")
|