Mooncake store use adxl interface (#3350)
Use adxl interface in mooncake store, mooncake PR https://github.com/kvcache-ai/Mooncake/pull/929 - vLLM version: v0.11.0rc3 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0 --------- Signed-off-by: LCAIZJ <leichao139636@163.com>
This commit is contained in:
@@ -2,11 +2,13 @@
|
||||
import os
|
||||
|
||||
# Third Party
|
||||
from mooncake.store import ReplicateConfig # type: ignore
|
||||
from vllm.config import ParallelConfig
|
||||
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank
|
||||
from vllm.utils import logger
|
||||
from vllm.utils import get_ip, logger
|
||||
|
||||
from vllm_ascend.distributed.mooncake.config_data import MooncakeEngineKey
|
||||
from vllm_ascend.distributed.mooncake.transfer_engine import get_global_te
|
||||
|
||||
from .config_data import MooncakeStoreConfig
|
||||
|
||||
@@ -36,17 +38,28 @@ class Mooncakestore():
|
||||
assert len(device_ids_list) > tp_rank
|
||||
device_id = device_ids_list[tp_rank]
|
||||
self.config = MooncakeStoreConfig.load_from_env()
|
||||
if self.config.protocol == "ascend":
|
||||
local_hostname = self.config.local_hostname + ":" + str(BASE_PORT + int(device_id)) + \
|
||||
":npu_" + str(device_id)
|
||||
else:
|
||||
local_hostname = self.config.local_hostname
|
||||
self.store = MooncakeDistributedStore()
|
||||
ret = self.store.setup(local_hostname, self.config.metadata_server,
|
||||
self.config.global_segment_size,
|
||||
self.config.local_buffer_size,
|
||||
self.config.protocol, self.config.device_name,
|
||||
self.config.master_server_address)
|
||||
if self.config.protocol == "ascend" and not self.config.use_ascend_direct:
|
||||
local_hostname = get_ip() + ":" + str(BASE_PORT + int(device_id)) + \
|
||||
":npu_" + str(device_id)
|
||||
ret = self.store.setup(local_hostname, self.config.metadata_server,
|
||||
self.config.global_segment_size,
|
||||
self.config.local_buffer_size,
|
||||
self.config.protocol,
|
||||
self.config.device_name,
|
||||
self.config.master_server_address)
|
||||
else:
|
||||
local_hostname = get_ip()
|
||||
transfer_engine = get_global_te(local_hostname, device_name=None)
|
||||
self.local_seg = local_hostname + ":" + str(
|
||||
transfer_engine.get_rpc_port())
|
||||
ret = self.store.setup(self.local_seg, self.config.metadata_server,
|
||||
self.config.global_segment_size,
|
||||
self.config.local_buffer_size,
|
||||
self.config.protocol,
|
||||
self.config.device_name,
|
||||
self.config.master_server_address,
|
||||
transfer_engine.get_engine())
|
||||
if ret != 0:
|
||||
msg = "Initialize mooncake failed."
|
||||
logger.error(msg)
|
||||
@@ -61,6 +74,31 @@ class Mooncakestore():
|
||||
def batch_exists(self, keys: list[str]) -> list[bool]:
    """Ask the store which of ``keys`` it currently holds.

    Args:
        keys: Store keys to look up.

    Returns:
        Whatever ``self.store.batch_is_exist`` returns for ``keys``,
        passed through unchanged.
    """
    presence = self.store.batch_is_exist(keys)
    return presence
|
||||
|
||||
def get_batch(self, keys: list[str], addrs: list[list[int]],
              sizes: list[list[int]], block_ids: list[int]):
    """Fetch a batch of keys from the store into caller-provided buffers.

    Args:
        keys: Store keys to fetch.
        addrs: Per-key lists of buffer addresses, forwarded to the store.
        sizes: Per-key lists of buffer sizes, forwarded to the store.
        block_ids: Accepted for interface compatibility; not used here.

    Returns:
        None. Failures are logged rather than raised.
    """
    try:
        # NOTE(review): the meaning of the trailing ``True`` flag is defined
        # by the mooncake store API and is not visible here -- confirm.
        res = self.store.batch_get_into_multi_buffers(
            keys, addrs, sizes, True)
        # Negative per-key results are treated as failures. Log once for
        # the whole batch instead of repeating the same (potentially very
        # long) message for every failed entry.
        if any(value < 0 for value in res):
            logger.error(f"Failed to get key {keys},res:{res}")
    except Exception as e:
        # Best-effort: a failed fetch must not propagate to the caller.
        logger.error(f"Failed to get key {keys}. {e}")
|
||||
|
||||
def put_batch(self, keys: list[str], addrs: list[list[int]],
              sizes: list[list[int]], block_ids: list[int]):
    """Store a batch of keys from caller-provided buffers.

    Args:
        keys: Store keys to write.
        addrs: Per-key lists of buffer addresses, forwarded to the store.
        sizes: Per-key lists of buffer sizes, forwarded to the store.
        block_ids: Accepted for interface compatibility; not used here.

    Returns:
        None. Failures are logged rather than raised.
    """
    try:
        config = ReplicateConfig()
        # NOTE(review): ``self.local_seg`` is only assigned on the setup
        # path that goes through the global transfer engine; on any other
        # path this attribute access raises and is reported by the handler
        # below -- confirm put_batch is never reached in that mode.
        config.preferred_segment = self.local_seg
        config.prefer_alloc_in_same_node = True
        res = self.store.batch_put_from_multi_buffers(
            keys, addrs, sizes, config)
        # Negative per-key results are treated as failures. Log once for
        # the whole batch instead of repeating the same message for every
        # failed entry.
        if any(value < 0 for value in res):
            logger.error(f"Failed to put key {keys},res:{res}")
    except Exception as e:
        # Best-effort: a failed put must not propagate to the caller.
        logger.error(f"Failed to put key {keys},error:{e}")
|
||||
|
||||
def get(self, key: MooncakeEngineKey, addr: list[int], size: list[int]):
|
||||
expect_res = sum(size)
|
||||
key_str = key.to_string()
|
||||
@@ -85,4 +123,4 @@ class Mooncakestore():
|
||||
|
||||
def close(self):
    """Close the underlying store and log that the connection was closed."""
    self.store.close()
    # Logged exactly once; the second copy of this statement was the
    # removed/added pair of the same line from the diff, not a real call.
    logger.info("Closed the mooncake store connection")
|
||||
|
||||
Reference in New Issue
Block a user