[KVPool][BugFix] Correctly initialize head_or_tp_rank for mooncake backend (#6498)

### What this PR does / why we need it?
This PR resolves the problem where local priority was not used on the
Mooncake node in the A2 environment.

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.15.0

---------

Signed-off-by: 房建伟 <fangjianwei@fangjianweideMacBook-Air.local>
Co-authored-by: Pz1116 <zpbzpb123123@gmail.com>
This commit is contained in:
DreamerLeader
2026-02-25 14:22:00 +08:00
committed by GitHub
parent 3da2ba22eb
commit 812c722cfb
2 changed files with 21 additions and 10 deletions

View File

@@ -14,6 +14,7 @@ from vllm.utils.network_utils import get_ip
from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.backend.backend import Backend
from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB DEFAULT_GLOBAL_SEGMENT_SIZE = 3355443200 # 3.125 GiB
DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB DEFAULT_LOCAL_BUFFER_SIZE = 1073741824 # 1.0 GiB
@@ -63,10 +64,14 @@ class MooncakeBackend(Backend):
def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]): def put(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
try: try:
config = ReplicateConfig() soc_version = get_ascend_device_type()
config.preferred_segment = self.local_seg if soc_version in {AscendDeviceType.A2}:
config.prefer_alloc_in_same_node = True res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes)
res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes, config) else:
config = ReplicateConfig()
config.preferred_segment = self.local_seg
config.prefer_alloc_in_same_node = True
res = self.store.batch_put_from_multi_buffers(keys, addrs, sizes, config)
for value in res: for value in res:
if value < 0: if value < 0:
logger.error(f"Failed to put key {keys},res:{res}") logger.error(f"Failed to put key {keys},res:{res}")
@@ -75,7 +80,11 @@ class MooncakeBackend(Backend):
def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]): def get(self, keys: list[str], addrs: list[list[int]], sizes: list[list[int]]):
try: try:
res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes, True) soc_version = get_ascend_device_type()
if soc_version in {AscendDeviceType.A2}:
res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes)
else:
res = self.store.batch_get_into_multi_buffers(keys, addrs, sizes, True)
for value in res: for value in res:
if value < 0: if value < 0:
logger.error(f"Failed to get key {keys}, res:{res}") logger.error(f"Failed to get key {keys}, res:{res}")

View File

@@ -30,6 +30,7 @@ from vllm_ascend.distributed.kv_transfer.kv_pool.ascend_store.kv_transfer import
KVCacheStoreSendingThread, KVCacheStoreSendingThread,
KVTransferThread, KVTransferThread,
) )
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
backend_map = { backend_map = {
"mooncake": { "mooncake": {
@@ -97,6 +98,12 @@ class KVPoolWorker:
self.head_or_tp_rank = self.tp_rank self.head_or_tp_rank = self.tp_rank
self.put_step = 1 self.put_step = 1
soc_version = get_ascend_device_type()
# be removed later
if self.backend == "mooncake" and soc_version in {AscendDeviceType.A3}:
self.head_or_tp_rank = self.tp_rank
self.put_step = 1
self.metadata = KeyMetadata( self.metadata = KeyMetadata(
model_config.model.rstrip("/").split("/")[-1], model_config.model.rstrip("/").split("/")[-1],
self.head_or_tp_rank, self.head_or_tp_rank,
@@ -140,11 +147,6 @@ class KVPoolWorker:
backend_module = importlib.import_module(backend_path) backend_module = importlib.import_module(backend_path)
real_backend = getattr(backend_module, backend_name) real_backend = getattr(backend_module, backend_name)
# be removed later
if self.backend == "mooncake":
self.head_or_tp_rank = self.tp_rank
self.put_step = 1
self.m_store = real_backend( # type: ignore[misc] self.m_store = real_backend( # type: ignore[misc]
parallel_config parallel_config
) )