diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py index 0e6a0355..11674b92 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py @@ -14,6 +14,7 @@ from vllm.utils.network_utils import get_ip from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te +from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB @@ -63,10 +64,14 @@ class MooncakeBackend(Backend): def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]): try: - config = ReplicateConfig() - config.preferred_segment = self.local_seg - config.prefer_alloc_in_same_node = True - res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes, config) + soc_version = get_ascend_device_type() + if soc_version in {AscendDeviceType.A2}: + res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes) + else: + config = ReplicateConfig() + config.preferred_segment = self.local_seg + config.prefer_alloc_in_same_node = True + res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes, config) for value in res: if value < 0: logger.error(f"Failed to put key {keys},res:{res}") @@ -75,7 +80,11 @@ class MooncakeBackend(Backend): def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]): try: - res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes, True) + soc_version = get_ascend_device_type() + if soc_version in {AscendDeviceType.A2}: + res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes) + else: + res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes, True) for value in res: if value < 0: logger.error(f"Failed to get key {keys}, res:{res}") diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py index c3f016b7..7dadd5ae 100644 --- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py +++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/pool_worker.py @@ -30,6 +30,7 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import KVCacheStoreSendingThread, KVTransferThread, ) +from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type backend_map = { "mooncake": { @@ -97,6 +98,12 @@ class KVPoolWorker: self.head_or_tp_rank = self.tp_rank self.put_step = 1 + soc_version = get_ascend_device_type() + # be removed later + if self.backend == "mooncake" and soc_version in {AscendDeviceType.A3}: + self.head_or_tp_rank = self.tp_rank + self.put_step = 1 + self.metadata = KeyMetadata( model_config.model.rstrip("/").split("/")[-1], self.head_or_tp_rank, @@ -140,11 +147,6 @@ class KVPoolWorker: backend_module = importlib.import_module(backend_path) real_backend = getattr(backend_module, backend_name) - # be removed later - if self.backend == "mooncake": - self.head_or_tp_rank = self.tp_rank - self.put_step = 1 - self.m_store = real_backend( # type: ignore[misc] parallel_config )