[feature] kv transfer support of ascend npu (#7795)
Co-authored-by: liupeng <liupeng374@huawei.com>
This commit is contained in:
6
python/sglang/srt/disaggregation/ascend/__init__.py
Normal file
6
python/sglang/srt/disaggregation/ascend/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
from sglang.srt.disaggregation.ascend.conn import (
|
||||
AscendKVBootstrapServer,
|
||||
AscendKVManager,
|
||||
AscendKVReceiver,
|
||||
AscendKVSender,
|
||||
)
|
||||
44
python/sglang/srt/disaggregation/ascend/conn.py
Normal file
44
python/sglang/srt/disaggregation/ascend/conn.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import logging
|
||||
|
||||
from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
|
||||
from sglang.srt.disaggregation.mooncake.conn import (
|
||||
MooncakeKVBootstrapServer,
|
||||
MooncakeKVManager,
|
||||
MooncakeKVReceiver,
|
||||
MooncakeKVSender,
|
||||
)
|
||||
from sglang.srt.utils import get_local_ip_by_remote
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AscendKVManager(MooncakeKVManager):
    """KV-cache transfer manager for Ascend NPUs.

    Reuses the Mooncake manager's orchestration logic, swapping in the
    Ascend transfer engine and an Ascend-specific registration strategy.
    """

    def init_engine(self):
        # Construct the Ascend transfer engine instead of the Mooncake one.
        self.engine = AscendTransferEngine(
            hostname=get_local_ip_by_remote(),
            npu_id=self.kv_args.gpu_id,
            disaggregation_mode=self.disaggregation_mode,
        )

    def register_buffer_to_engine(self):
        # One registration starting at the first KV pointer covers the whole
        # KV region.
        # NOTE(review): assumes the KV buffers form one contiguous block —
        # confirm against the Ascend KV pool layout.
        total_kv_bytes = sum(self.kv_args.kv_data_lens)
        self.engine.register(self.kv_args.kv_data_ptrs[0], total_kv_bytes)
        # The Ascend backend optimizes batch registration for the many small
        # auxiliary blocks.
        self.engine.batch_register(
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        )
|
||||
|
||||
|
||||
class AscendKVSender(MooncakeKVSender):
    """Ascend KV sender; inherits Mooncake sender behavior unchanged."""
|
||||
|
||||
|
||||
class AscendKVReceiver(MooncakeKVReceiver):
    """Ascend KV receiver; inherits Mooncake receiver behavior unchanged."""
|
||||
|
||||
|
||||
class AscendKVBootstrapServer(MooncakeKVBootstrapServer):
    """Ascend bootstrap server; inherits Mooncake behavior unchanged."""
|
||||
58
python/sglang/srt/disaggregation/ascend/transfer_engine.py
Normal file
58
python/sglang/srt/disaggregation/ascend/transfer_engine.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
||||
from sglang.srt.disaggregation.utils import DisaggregationMode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AscendTransferEngine(MooncakeTransferEngine):
    """Transfer engine backed by Huawei's ``mf_adapter`` for Ascend NPUs.

    Mirrors the Mooncake engine interface but initializes against the
    centralized MF store addressed by the ``ASCEND_MF_STORE_URL`` env var.
    """

    def __init__(
        self, hostname: str, npu_id: int, disaggregation_mode: DisaggregationMode
    ):
        try:
            from mf_adapter import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mf_adapter, for details, see docs/backend/pd_disaggregation.md"
            ) from e

        self.engine = TransferEngine()
        self.hostname = hostname
        self.npu_id = npu_id

        # Centralized storage address of the AscendTransferEngine.
        # Fail fast with a clear message instead of passing None into the
        # native initialize() call, which would only surface an opaque
        # non-zero status later.
        self.store_url = os.getenv("ASCEND_MF_STORE_URL")
        if not self.store_url:
            raise ValueError(
                "ASCEND_MF_STORE_URL environment variable is not set; it must "
                "point to the centralized storage address of the "
                "AscendTransferEngine."
            )
        if disaggregation_mode == DisaggregationMode.PREFILL:
            self.role = "Prefill"
        elif disaggregation_mode == DisaggregationMode.DECODE:
            self.role = "Decode"
        else:
            logger.error(f"Unsupported DisaggregationMode: {disaggregation_mode}")
            raise ValueError(f"Unsupported DisaggregationMode: {disaggregation_mode}")
        # The session id doubles as this engine's network identity.
        self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
        self.initialize()

    def initialize(self) -> None:
        """Initialize the ascend transfer instance.

        Raises:
            RuntimeError: if the native engine reports a non-zero status.
        """
        ret_value = self.engine.initialize(
            self.store_url,
            self.session_id,
            self.role,
            self.npu_id,
        )
        if ret_value != 0:
            logger.error("Ascend Transfer Engine initialization failed.")
            raise RuntimeError("Ascend Transfer Engine initialization failed.")

    def batch_register(self, ptrs: List[int], lengths: List[int]):
        """Best-effort batch registration of many small memory blocks.

        Failures are logged but intentionally do not raise, preserving the
        original best-effort contract.
        """
        try:
            ret_value = self.engine.batch_register_memory(ptrs, lengths)
        except Exception:
            # Mark register as failed
            ret_value = -1
        if ret_value != 0:
            # warning (not debug): a silently dropped registration would
            # otherwise be invisible under default logging configuration.
            logger.warning(f"Ascend memory registration for ptr {ptrs} failed.")
|
||||
@@ -132,13 +132,9 @@ class MooncakeKVManager(BaseKVManager):
|
||||
):
|
||||
self.kv_args = args
|
||||
self.local_ip = get_local_ip_auto()
|
||||
self.engine = MooncakeTransferEngine(
|
||||
hostname=self.local_ip,
|
||||
gpu_id=self.kv_args.gpu_id,
|
||||
ib_device=self.kv_args.ib_device,
|
||||
)
|
||||
self.is_mla_backend = is_mla_backend
|
||||
self.disaggregation_mode = disaggregation_mode
|
||||
self.init_engine()
|
||||
# for p/d multi node infer
|
||||
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
||||
self.dist_init_addr = server_args.dist_init_addr
|
||||
@@ -225,6 +221,13 @@ class MooncakeKVManager(BaseKVManager):
|
||||
self.failure_records: Dict[int, str] = {}
|
||||
self.failure_lock = threading.Lock()
|
||||
|
||||
def init_engine(self):
    """Construct the Mooncake transfer engine for this manager.

    Kept as a separate hook so backend subclasses can override engine
    construction without re-implementing the rest of __init__.
    """
    kv_args = self.kv_args
    self.engine = MooncakeTransferEngine(
        hostname=self.local_ip,
        gpu_id=kv_args.gpu_id,
        ib_device=kv_args.ib_device,
    )
|
||||
|
||||
def register_buffer_to_engine(self):
|
||||
for kv_data_ptr, kv_data_len in zip(
|
||||
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import logging
|
||||
from typing import List, Optional
|
||||
|
||||
from sglang.srt.utils import get_bool_env_var, get_free_port
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -53,12 +55,21 @@ class MooncakeTransferEngine:
|
||||
device_name: Optional[str],
|
||||
) -> None:
|
||||
"""Initialize the mooncake instance."""
|
||||
ret_value = self.engine.initialize(
|
||||
hostname,
|
||||
"P2PHANDSHAKE",
|
||||
"rdma",
|
||||
device_name if device_name is not None else "",
|
||||
)
|
||||
if get_bool_env_var("ENABLE_ASCEND_TRANSFER_WITH_MOONCAKE", "false"):
|
||||
hostname += f":{get_free_port()}:npu_{self.gpu_id}"
|
||||
ret_value = self.engine.initialize(
|
||||
hostname,
|
||||
"P2PHANDSHAKE",
|
||||
"ascend",
|
||||
device_name if device_name is not None else "",
|
||||
)
|
||||
else:
|
||||
ret_value = self.engine.initialize(
|
||||
hostname,
|
||||
"P2PHANDSHAKE",
|
||||
"rdma",
|
||||
device_name if device_name is not None else "",
|
||||
)
|
||||
if ret_value != 0:
|
||||
logger.error("Mooncake Transfer Engine initialization failed.")
|
||||
raise RuntimeError("Mooncake Transfer Engine initialization failed.")
|
||||
|
||||
@@ -15,7 +15,7 @@ import requests
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from sglang.srt.utils import get_ip
|
||||
from sglang.srt.utils import get_ip, is_npu
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from sglang.srt.managers.schedule_batch import Req
|
||||
@@ -94,8 +94,12 @@ class MetadataBuffers:
|
||||
custom_mem_pool: torch.cuda.MemPool = None,
|
||||
):
|
||||
self.custom_mem_pool = custom_mem_pool
|
||||
device = "cuda" if self.custom_mem_pool else "cpu"
|
||||
|
||||
device = "cpu"
|
||||
if is_npu():
|
||||
# For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
|
||||
device = "npu"
|
||||
elif self.custom_mem_pool:
|
||||
device = "cuda"
|
||||
with (
|
||||
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
||||
if self.custom_mem_pool
|
||||
@@ -200,6 +204,7 @@ class MetadataBuffers:
|
||||
class TransferBackend(Enum):
    # Supported disaggregation KV-transfer backends; string values match the
    # --disaggregation-transfer-backend CLI choices.
    MOONCAKE = "mooncake"
    NIXL = "nixl"
    ASCEND = "ascend"
    # FAKE is a stub backend (no real transfer), presumably for testing —
    # TODO confirm against its usage elsewhere.
    FAKE = "fake"
|
||||
|
||||
|
||||
@@ -231,6 +236,23 @@ def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
|
||||
KVClassType.BOOTSTRAP_SERVER: MooncakeKVBootstrapServer,
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
elif transfer_backend == TransferBackend.ASCEND:
|
||||
from sglang.srt.disaggregation.ascend import (
|
||||
AscendKVBootstrapServer,
|
||||
AscendKVManager,
|
||||
AscendKVReceiver,
|
||||
AscendKVSender,
|
||||
)
|
||||
from sglang.srt.disaggregation.base import KVArgs
|
||||
|
||||
class_mapping = {
|
||||
KVClassType.KVARGS: KVArgs,
|
||||
KVClassType.MANAGER: AscendKVManager,
|
||||
KVClassType.SENDER: AscendKVSender,
|
||||
KVClassType.RECEIVER: (AscendKVReceiver),
|
||||
KVClassType.BOOTSTRAP_SERVER: AscendKVBootstrapServer,
|
||||
}
|
||||
return class_mapping.get(class_type)
|
||||
elif transfer_backend == TransferBackend.NIXL:
|
||||
from sglang.srt.disaggregation.base import KVArgs
|
||||
from sglang.srt.disaggregation.nixl import (
|
||||
|
||||
@@ -285,6 +285,20 @@ class TokenizerManager:
|
||||
self.bootstrap_server = kv_bootstrap_server_class(
|
||||
self.server_args.disaggregation_bootstrap_port
|
||||
)
|
||||
is_create_store = (
|
||||
self.server_args.node_rank == 0
|
||||
and self.server_args.disaggregation_transfer_backend == "ascend"
|
||||
)
|
||||
if is_create_store:
|
||||
try:
|
||||
from mf_adapter import create_config_store
|
||||
|
||||
ascend_url = os.getenv("ASCEND_MF_STORE_URL")
|
||||
create_config_store(ascend_url)
|
||||
except Exception as e:
|
||||
error_message = f"Failed create mf store, invalid ascend_url."
|
||||
error_message += f" With exception {e}"
|
||||
raise error_message
|
||||
|
||||
# For load balancing
|
||||
self.current_load = 0
|
||||
|
||||
@@ -604,32 +604,49 @@ class AscendTokenToKVPool(MHATokenToKVPool):
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
# [size, head_num, head_dim] for each layer
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.k_buffer = [
|
||||
torch.zeros(
|
||||
(
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
self.head_num,
|
||||
self.head_dim,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
self.v_buffer = [
|
||||
torch.zeros(
|
||||
(
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
self.head_num,
|
||||
self.head_dim,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(self.layer_num)
|
||||
]
|
||||
# Continuous memory improves the efficiency of Ascend`s transmission backend,
|
||||
# while other backends remain unchanged.
|
||||
self.kv_buffer = torch.zeros(
|
||||
(
|
||||
2,
|
||||
self.layer_num,
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
self.head_num,
|
||||
self.head_dim,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
self.k_buffer = self.kv_buffer[0]
|
||||
self.v_buffer = self.kv_buffer[1]
|
||||
|
||||
# for disagg
|
||||
def get_contiguous_buf_infos(self):
|
||||
# layer_num x [seq_len, head_num, head_dim]
|
||||
# layer_num x [page_num, page_size, head_num, head_dim]
|
||||
kv_data_ptrs = [
|
||||
self.get_key_buffer(i).data_ptr()
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
] + [
|
||||
self.get_value_buffer(i).data_ptr()
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
]
|
||||
kv_data_lens = [
|
||||
self.get_key_buffer(i).nbytes
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
] + [
|
||||
self.get_value_buffer(i).nbytes
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
]
|
||||
kv_item_lens = [
|
||||
self.get_key_buffer(i)[0].nbytes
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
] + [
|
||||
self.get_value_buffer(i)[0].nbytes
|
||||
for i in range(self.start_layer, self.start_layer + self.layer_num)
|
||||
]
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
@@ -969,18 +986,16 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
||||
|
||||
with self.memory_saver_adapter.region(GPU_MEMORY_TYPE_KV_CACHE):
|
||||
# The padded slot 0 is used for writing dummy outputs from padded tokens.
|
||||
self.kv_buffer = [
|
||||
torch.zeros(
|
||||
(
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
for _ in range(layer_num)
|
||||
]
|
||||
self.kv_buffer = torch.zeros(
|
||||
(
|
||||
layer_num,
|
||||
self.size // self.page_size + 1,
|
||||
self.page_size,
|
||||
self.kv_lora_rank + self.qk_rope_head_dim,
|
||||
),
|
||||
dtype=self.store_dtype,
|
||||
device=self.device,
|
||||
)
|
||||
|
||||
self.layer_transfer_counter = None
|
||||
|
||||
@@ -990,6 +1005,14 @@ class AscendMLAPagedTokenToKVPool(MLATokenToKVPool):
|
||||
)
|
||||
self.mem_usage = kv_size / GB
|
||||
|
||||
# for disagg
|
||||
def get_contiguous_buf_infos(self):
|
||||
# MLA has only one kv_buffer, so only the information of this buffer needs to be returned.
|
||||
kv_data_ptrs = [self.kv_buffer[i].data_ptr() for i in range(self.layer_num)]
|
||||
kv_data_lens = [self.kv_buffer[i].nbytes for i in range(self.layer_num)]
|
||||
kv_item_lens = [self.kv_buffer[i][0].nbytes for i in range(self.layer_num)]
|
||||
return kv_data_ptrs, kv_data_lens, kv_item_lens
|
||||
|
||||
def set_kv_buffer(
|
||||
self,
|
||||
layer: RadixAttention,
|
||||
|
||||
@@ -1621,7 +1621,7 @@ class ServerArgs:
|
||||
"--disaggregation-transfer-backend",
|
||||
type=str,
|
||||
default=ServerArgs.disaggregation_transfer_backend,
|
||||
choices=["mooncake", "nixl"],
|
||||
choices=["mooncake", "nixl", "ascend"],
|
||||
help="The backend for disaggregation transfer. Default is mooncake.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
Reference in New Issue
Block a user