diff --git a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py
index 25a103ca..3375e741 100644
--- a/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py
+++ b/vllm_ascend/distributed/kv_transfer/kv_pool/ascend_store/backend/mooncake_backend.py
@@ -2,6 +2,8 @@
 import json
 import os
 import re
 from dataclasses import dataclass
 from typing import Union
 
+import torch
+
@@ -32,6 +34,7 @@ class MooncakeBackend(Backend):
                 "to run vLLM with MooncakeConnector.") from e
         self.config = MooncakeStoreConfig.load_from_env()
         self.store = MooncakeDistributedStore()
+        self.rank = parallel_config.rank
         if self.config.protocol == "ascend":
             local_hostname = get_ip()
             transfer_engine = global_te.get_transfer_engine(local_hostname,
@@ -50,6 +53,13 @@ class MooncakeBackend(Backend):
                 logger.error(msg)
                 raise RuntimeError(msg)
 
+    def set_device(self) -> None:
+        """Select ``npu:<self.rank>`` via ``torch.npu.set_device``."""
+        # NOTE(review): self.rank is parallel_config.rank used directly as
+        # the device index -- confirm this is valid on multi-node setups.
+        device = torch.device(f"npu:{self.rank}")
+        torch.npu.set_device(device)
+
     def register_buffer(self, ptrs: list[int], lengths: list[int]):
         global_te.register_buffer(ptrs, lengths)
 