118 lines
4.0 KiB
Python
118 lines
4.0 KiB
Python
import concurrent.futures
|
|
import logging
|
|
from typing import List, Tuple
|
|
|
|
import numpy as np
|
|
import numpy.typing as npt
|
|
|
|
from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
|
|
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
|
from sglang.srt.disaggregation.mooncake.conn import (
|
|
MooncakeKVBootstrapServer,
|
|
MooncakeKVManager,
|
|
MooncakeKVReceiver,
|
|
MooncakeKVSender,
|
|
)
|
|
from sglang.srt.utils import get_local_ip_by_remote
|
|
|
|
# Module-level logger, named after this module.
# NOTE(review): not referenced by any code visible in this file — drop if unused.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class AscendKVManager(MooncakeKVManager):
    """KV-cache transfer manager for the Ascend disaggregation backend.

    Subclasses ``MooncakeKVManager`` and overrides engine creation, buffer
    registration and KV-cache sending to use ``AscendTransferEngine``.
    """

    def init_engine(self):
        """Create the Ascend transfer engine for this host's IP and NPU id."""
        # TransferEngine initialized on ascend.
        local_ip = get_local_ip_by_remote()
        self.engine = AscendTransferEngine(
            hostname=local_ip,
            npu_id=self.kv_args.gpu_id,
            disaggregation_mode=self.disaggregation_mode,
        )

    def register_buffer_to_engine(self):
        """Register the KV data buffers and the auxiliary buffers with the engine."""
        self.engine.batch_register(self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens)
        # The Ascend backend optimizes batch registration for small memory blocks.
        self.engine.batch_register(
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        )

    def send_kvcache(
        self,
        mooncake_session_id: str,
        prefill_kv_indices: npt.NDArray[np.int32],
        dst_kv_ptrs: list[int],
        dst_kv_indices: npt.NDArray[np.int32],
        executor: concurrent.futures.ThreadPoolExecutor,
    ):
        """Copy the selected KV-cache slots of every layer to the remote side.

        Args:
            mooncake_session_id: Session id passed through to ``_transfer_data``.
            prefill_kv_indices: Local (source) slot indices to send.
            dst_kv_ptrs: Per-layer destination base addresses, indexed by the
                same layer id as ``self.kv_args.kv_data_ptrs``.
            dst_kv_indices: Destination slot indices, paired element-wise with
                ``prefill_kv_indices``.
            executor: Thread pool; only used when ``self.enable_custom_mem_pool``
                is set (one task per layer).

        Returns:
            A status from ``_transfer_data``. In the batched path, the single
            call's return value is passed through. In the per-layer path, this
            is the first non-zero status to complete, or ``0`` if every layer
            returned ``0``.
        """
        # Group by indices. The address math below treats each group as one
        # run that is contiguous on both the source and destination side.
        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
            prefill_kv_indices, dst_kv_indices
        )

        # (src_start_slot, dst_start_slot, num_slots) per block. These are
        # layer-independent, so compute them once rather than once per layer.
        block_spans = [
            (int(prefill_block[0]), int(dst_block[0]), len(prefill_block))
            for prefill_block, dst_block in zip(prefill_kv_blocks, dst_kv_blocks)
        ]

        num_layers = len(self.kv_args.kv_data_ptrs)
        layers_params = [
            (
                self.kv_args.kv_data_ptrs[layer_id],
                dst_kv_ptrs[layer_id],
                self.kv_args.kv_item_lens[layer_id],
            )
            for layer_id in range(num_layers)
        ]

        def set_transfer_blocks(
            src_ptr: int, dst_ptr: int, item_len: int
        ) -> List[Tuple[int, int, int]]:
            """Return ``(src_addr, dst_addr, length)`` for each block of one layer."""
            return [
                (
                    src_ptr + src_start * item_len,
                    dst_ptr + dst_start * item_len,
                    item_len * num_slots,
                )
                for src_start, dst_start, num_slots in block_spans
            ]

        # Worker function for processing a single layer
        def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
            transfer_blocks = set_transfer_blocks(src_ptr, dst_ptr, item_len)
            return self._transfer_data(mooncake_session_id, transfer_blocks)

        # Worker function for processing all layers in a batch
        def process_layers(layers_params: List[Tuple[int, int, int]]) -> int:
            transfer_blocks = []
            for src_ptr, dst_ptr, item_len in layers_params:
                transfer_blocks.extend(set_transfer_blocks(src_ptr, dst_ptr, item_len))
            return self._transfer_data(mooncake_session_id, transfer_blocks)

        if not self.enable_custom_mem_pool:
            # Combining all layers' params in one batch transfer is more
            # efficient compared to using multiple threads.
            return process_layers(layers_params)

        futures = [
            executor.submit(process_layer, src_ptr, dst_ptr, item_len)
            for src_ptr, dst_ptr, item_len in layers_params
        ]
        try:
            for future in concurrent.futures.as_completed(futures):
                status = future.result()
                if status != 0:
                    return status
        finally:
            # Cancel layers that have not started, whether we stop because of
            # a non-zero status or because a worker raised. Future.cancel() is
            # a no-op on running or finished futures, so this is harmless on
            # the success path.
            for pending in futures:
                pending.cancel()
        return 0
|
|
|
|
|
|
class AscendKVSender(MooncakeKVSender):
    """Ascend KV sender; uses ``MooncakeKVSender`` behavior unchanged."""
|
|
|
|
|
|
class AscendKVReceiver(MooncakeKVReceiver):
    """Ascend KV receiver; uses ``MooncakeKVReceiver`` behavior unchanged."""
|
|
|
|
|
|
class AscendKVBootstrapServer(MooncakeKVBootstrapServer):
    """Ascend bootstrap server; uses ``MooncakeKVBootstrapServer`` behavior unchanged."""
|