[feature] kv transfer support of ascend npu (#7795)

Co-authored-by: liupeng <liupeng374@huawei.com>
This commit is contained in:
ronnie_zheng
2025-07-11 10:07:51 +03:00
committed by GitHub
parent 615553079d
commit 86044712c6
10 changed files with 267 additions and 53 deletions

View File

@@ -0,0 +1,6 @@
from sglang.srt.disaggregation.ascend.conn import (
AscendKVBootstrapServer,
AscendKVManager,
AscendKVReceiver,
AscendKVSender,
)

View File

@@ -0,0 +1,44 @@
import logging
from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
from sglang.srt.disaggregation.mooncake.conn import (
MooncakeKVBootstrapServer,
MooncakeKVManager,
MooncakeKVReceiver,
MooncakeKVSender,
)
from sglang.srt.utils import get_local_ip_by_remote
logger = logging.getLogger(__name__)
class AscendKVManager(MooncakeKVManager):
    """Mooncake KV manager variant that drives an ``AscendTransferEngine``.

    Only engine construction and buffer registration are overridden here;
    everything else is inherited from ``MooncakeKVManager``.
    """

    def init_engine(self):
        """Build the Ascend transfer engine and store it on ``self.engine``."""
        # The hostname handed to the engine comes from get_local_ip_by_remote().
        self.engine = AscendTransferEngine(
            hostname=get_local_ip_by_remote(),
            npu_id=self.kv_args.gpu_id,
            disaggregation_mode=self.disaggregation_mode,
        )

    def register_buffer_to_engine(self):
        """Register the KV buffers and the auxiliary buffers with the engine."""
        kv_args = self.kv_args

        # All KV buffers are registered as a single region: it starts at the
        # first KV pointer and spans the sum of every KV buffer length.
        # NOTE(review): this presumably relies on the KV buffers being laid
        # out contiguously in memory - confirm against the allocator.
        kv_base_ptr = kv_args.kv_data_ptrs[0]
        kv_total_len = sum(kv_args.kv_data_lens)
        self.engine.register(kv_base_ptr, kv_total_len)

        # The Ascend backend optimizes batch registration for small memory
        # blocks, so the auxiliary buffers go through the batch call.
        self.engine.batch_register(kv_args.aux_data_ptrs, kv_args.aux_data_lens)
class AscendKVSender(MooncakeKVSender):
    """Ascend KV sender; inherits ``MooncakeKVSender`` without changes."""
class AscendKVReceiver(MooncakeKVReceiver):
    """Ascend KV receiver; inherits ``MooncakeKVReceiver`` without changes."""
class AscendKVBootstrapServer(MooncakeKVBootstrapServer):
    """Ascend bootstrap server; inherits ``MooncakeKVBootstrapServer`` without changes."""

View File

@@ -0,0 +1,58 @@
import logging
import os
from typing import List, Optional
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
from sglang.srt.disaggregation.utils import DisaggregationMode
logger = logging.getLogger(__name__)
class AscendTransferEngine(MooncakeTransferEngine):
    """Transfer engine for Ascend NPUs, backed by ``mf_adapter.TransferEngine``.

    The class subclasses ``MooncakeTransferEngine`` but replaces construction,
    initialization and adds a batch registration entry point.
    """

    def __init__(
        self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
    ):
        """Create and initialize the underlying ``mf_adapter`` engine.

        Args:
            hostname: Local host address; used as the first half of the
                session id.
            npu_id: NPU device index passed to the engine on initialization.
            disaggregation_mode: ``PREFILL`` or ``DECODE``; mapped to the role
                string given to the engine.

        Raises:
            ImportError: If ``mf_adapter`` is not installed.
            ValueError: If ``disaggregation_mode`` is neither ``PREFILL`` nor
                ``DECODE``.
            RuntimeError: If the engine reports a non-zero status while
                initializing.
        """
        # Imported lazily so the module can be loaded without mf_adapter.
        try:
            from mf_adapter import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
            ) from e

        # NOTE: the parent __init__ is not invoked; every attribute used by
        # this class is set up below.
        self.engine = TransferEngine()
        self.hostname = hostname
        self.npu_id = npu_id

        # Centralized storage address of the AscendTransferEngine
        self.store_url = os.getenv("ASCEND_MF_STORE_URL")
        if not self.store_url:
            # Surface the misconfiguration explicitly instead of handing an
            # empty value to the native engine without any diagnostic.
            logger.warning(
                "ASCEND_MF_STORE_URL is not set; the Ascend transfer engine "
                "will be initialized without a store URL."
            )

        if disaggregation_mode == DisaggregationMode.PREFILL:
            self.role = "Prefill"
        elif disaggregation_mode == DisaggregationMode.DECODE:
            self.role = "Decode"
        else:
            logger.error(f"Unsupported DisaggregationMode: {disaggregation_mode}")
            raise ValueError(f"Unsupported DisaggregationMode: {disaggregation_mode}")

        self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
        self.initialize()

    def initialize(self) -> None:
        """Initialize the ascend transfer instance.

        Raises:
            RuntimeError: If the engine returns a non-zero status.
        """
        ret_value = self.engine.initialize(
            self.store_url,
            self.session_id,
            self.role,
            self.npu_id,
        )
        if ret_value != 0:
            logger.error("Ascend Transfer Engine initialization failed.")
            raise RuntimeError("Ascend Transfer Engine initialization failed.")

    def batch_register(self, ptrs: List[int], lengths: List[int]) -> None:
        """Register several memory regions with the engine in one call.

        Registration is best-effort: a failure is logged and never raised.

        Args:
            ptrs: Start addresses of the regions.
            lengths: Region lengths, passed alongside ``ptrs``.
        """
        try:
            ret_value = self.engine.batch_register_memory(ptrs, lengths)
        except Exception:
            # Mark register as failed, but keep the traceback so the cause of
            # the failure is not silently lost.
            logger.exception("Ascend batch memory registration raised an exception.")
            ret_value = -1
        if ret_value != 0:
            logger.debug(f"Ascend memory registration for ptr {ptrs} failed.")

View File

@@ -132,13 +132,9 @@ class MooncakeKVManager(BaseKVManager):
):
self.kv_args = args
self.local_ip = get_local_ip_auto()
self.engine = MooncakeTransferEngine(
hostname=self.local_ip,
gpu_id=self.kv_args.gpu_id,
ib_device=self.kv_args.ib_device,
)
self.is_mla_backend = is_mla_backend
self.disaggregation_mode = disaggregation_mode
self.init_engine()
# for p/d multi node infer
self.bootstrap_port = server_args.disaggregation_bootstrap_port
self.dist_init_addr = server_args.dist_init_addr
@@ -225,6 +221,13 @@ class MooncakeKVManager(BaseKVManager):
self.failure_records: Dict[int, str] = {}
self.failure_lock = threading.Lock()
def init_engine(self):
    """Create the Mooncake transfer engine and store it on ``self.engine``.

    Kept as a separate method so subclasses can substitute another engine.
    """
    kv_args = self.kv_args
    self.engine = MooncakeTransferEngine(
        hostname=self.local_ip,
        gpu_id=kv_args.gpu_id,
        ib_device=kv_args.ib_device,
    )
def register_buffer_to_engine(self):
for kv_data_ptr, kv_data_len in zip(
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens

View File

@@ -1,6 +1,8 @@
import logging
from typing import List, Optional
from sglang.srt.utils import get_bool_env_var, get_free_port
logger = logging.getLogger(__name__)
@@ -53,12 +55,21 @@ class MooncakeTransferEngine:
device_name: Optional[str],
) -> None:
"""Initialize the mooncake instance."""
ret_value = self.engine.initialize(
hostname,
"P2PHANDSHAKE",
"rdma",
device_name if device_name is not None else "",
)
if get_bool_env_var("ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE", "false"):
hostname += f":{get_free_port()}:npu_{self.gpu_id}"
ret_value = self.engine.initialize(
hostname,
"P2PHANDSHAKE",
"ascend",
device_name if device_name is not None else "",
)
else:
ret_value = self.engine.initialize(
hostname,
"P2PHANDSHAKE",
"rdma",
device_name if device_name is not None else "",
)
if ret_value != 0:
logger.error("Mooncake Transfer Engine initialization failed.")
raise RuntimeError("Mooncake Transfer Engine initialization failed.")

View File

@@ -15,7 +15,7 @@ import requests
import torch
import torch.distributed as dist
from sglang.srt.utils import get_ip
from sglang.srt.utils import get_ip, is_npu
if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import Req
@@ -94,8 +94,12 @@ class MetadataBuffers:
custom_mem_pool: torch.cuda.MemPool = None,
):
self.custom_mem_pool = custom_mem_pool
device = "cuda" if self.custom_mem_pool else "cpu"
device = "cpu"
if is_npu():
# For the Ascend backend, output tokens are placed on the NPU and transferred over the D2D channel.
device = "npu"
elif self.custom_mem_pool:
device = "cuda"
with (
torch.cuda.use_mem_pool(self.custom_mem_pool)
if self.custom_mem_pool
@@ -200,6 +204,7 @@ class MetadataBuffers:
class TransferBackend(Enum):
    """KV transfer backends; each value is the backend's lowercase name."""

    MOONCAKE = "mooncake"
    NIXL = "nixl"
    ASCEND = "ascend"
    FAKE = "fake"
@@ -231,6 +236,23 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
}
return class_mapping.get(class_type)
elif transfer_backend == TransferBackend.ASCEND:
from sglang.srt.disaggregation.ascend import (
AscendKVBootstrapServer,
AscendKVManager,
AscendKVReceiver,
AscendKVSender,
)
from sglang.srt.disaggregation.base import KVArgs
class_mapping = {
KVClassType.KVARGS: KVArgs,
KVClassType.MANAGER: AscendKVManager,
KVClassType.SENDER: AscendKVSender,
KVClassType.RECEIVER: (AscendKVReceiver),
KVClassType.BOOTSTRAP_SERVER: AscendKVBootstrapServer,
}
return class_mapping.get(class_type)
elif transfer_backend == TransferBackend.NIXL:
from sglang.srt.disaggregation.base import KVArgs
from sglang.srt.disaggregation.nixl import (