2025-09-30 15:10:29 +08:00
# SPDX-License-Identifier: Apache-2.0
import contextlib
2025-10-30 22:21:11 +08:00
import copy
2025-09-30 15:10:29 +08:00
import hashlib
import math
import queue
import struct
import threading
import time
2025-10-30 22:21:11 +08:00
from collections import defaultdict , deque
2025-09-30 15:10:29 +08:00
from collections . abc import Iterator
from concurrent . futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import TYPE_CHECKING , Any , Callable , List , Optional , Tuple
import httpx
import msgspec
import numpy as np
import numpy . typing as npt
import torch
import zmq
from mooncake . engine import TransferEngine # type: ignore
2025-10-30 22:21:11 +08:00
from vllm . config import VllmConfig
2025-09-30 15:10:29 +08:00
from vllm . distributed . kv_transfer . kv_connector . v1 . base import (
KVConnectorBase_V1 , KVConnectorMetadata , KVConnectorRole )
from vllm . distributed . parallel_state import ( get_tensor_model_parallel_rank ,
get_tp_group , get_world_group )
from vllm . utils import get_ip , logger , make_zmq_path , make_zmq_socket
from vllm . v1 . core . sched . output import SchedulerOutput
import vllm_ascend . envs as envs_ascend
from vllm_ascend . ascend_config import get_ascend_config
from vllm_ascend . distributed . utils import ( align_memory ,
kv_alltoall_and_rearrange )
if TYPE_CHECKING:
    # Imported only for type annotations to avoid import cycles at runtime.
    from vllm.attention.backends.abstract import AttentionMetadata
    from vllm.forward_context import ForwardContext
    from vllm.v1.core.kv_cache_manager import KVCacheBlocks
    from vllm.v1.request import Request

# Message tags exchanged over the ZMQ side channel between prefill (sender)
# and decode (receiver) workers.
GET_META_MSG = b"get_meta_msg"
DONE_SENDING_MSG = b"done_sending_msg"
class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
    """Handshake metadata a worker advertises over the ZMQ side channel.

    Encoded with msgspec and returned in reply to GET_META_MSG so a remote
    peer can address this worker's registered KV-cache regions.
    """

    # RPC port of this worker's mooncake TransferEngine.
    te_rpc_port: int
    # Base addresses of every registered KV-cache region, in registration
    # order ([K, V] pairs per layer for non-MLA; nope/rope pairs for MLA).
    kv_caches_base_addr: list[int]
@dataclass
class ReqMeta:
    """Per-request KV-transfer metadata passed from scheduler to worker."""

    # Block ids of this request's KV cache on the local device.
    local_block_ids: list[int]
    # Prompt token ids for the request (may be empty when unused).
    token_ids: list[int]
    # Not None if layer-wise is disabled
    # Block ids on the remote peer, paired with local_block_ids.
    remote_block_ids: list[int]
    # Engine id of the remote peer, if known.
    remote_engine_id: Optional[str]
    # Host and side-channel port of the remote peer.
    remote_host: Optional[str]
    remote_port: Optional[int]
    # TransferEngine RPC port on the remote peer.
    remote_te_rpc_port: Optional[int]
    # Base addresses of the remote peer's registered KV-cache regions.
    remote_kv_caches_base_addr: Optional[list[int]]
    # "metaserver" value from kv_transfer_params, if provided.
    metaserver: Optional[str]
class KVCacheSendingLayerThread(threading.Thread):
    """Daemon thread that pushes KV-cache layers from a prefill worker to a
    remote decode worker through the mooncake TransferEngine.

    Work items are (req_id, req_meta, layer_index, key, value) tuples pulled
    from ``send_queue``. Sending the final layer of a request triggers
    ``callback_func(req_id, req_meta)``.
    """

    def __init__(self,
                 engine: TransferEngine,
                 total_layers: int,
                 ready_event: threading.Event,
                 tp_rank: int,
                 pd_head_ratio: int,
                 num_head_replica: int,
                 kv_cache_base_addr: list[int],
                 use_mla: bool,
                 block_len: list[int],
                 first_kv_cache: torch.Tensor,
                 callback_func: Callable[..., None] = lambda x: None):
        super().__init__(daemon=True, name="KVCacheSendingLayerThread")
        self.engine = engine
        self.tp_rank = tp_rank
        self.pd_head_ratio = pd_head_ratio
        self.num_head_replica = num_head_replica
        self.kv_caches_base_addr = kv_cache_base_addr
        self.total_layers = total_layers
        self.use_mla = use_mla
        self.block_len = block_len

        if self.pd_head_ratio > 1:
            # Register staging buffers for the unequal-TP (prefill heads !=
            # decode heads) case: K/V slices are copied into these contiguous,
            # 2 MiB-aligned buffers before being written to the remote side.
            alignment = 2 * 1024 * 1024
            self.k_buffer = torch.zeros(first_kv_cache.numel() + alignment,
                                        dtype=first_kv_cache.dtype,
                                        device=first_kv_cache.device)
            self.k_buffer = align_memory(
                self.k_buffer, alignment)[:first_kv_cache.numel()].view(
                    -1, first_kv_cache.shape[-1])
            self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment,
                                        dtype=first_kv_cache.dtype,
                                        device=first_kv_cache.device)
            self.v_buffer = align_memory(
                self.v_buffer, alignment)[:first_kv_cache.numel()].view(
                    -1, first_kv_cache.shape[-1])
            for tensor in (self.k_buffer, self.v_buffer):
                assert tensor.data_ptr(
                ) % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
                # NOTE(review): the region length passed here is numel()
                # (element count), not numel() * element_size() (bytes) --
                # confirm what TransferEngine.register_memory expects.
                ret_value = self.engine.register_memory(
                    tensor.data_ptr(), tensor.numel())
                logger.info(
                    f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
                )
                if ret_value != 0:
                    raise RuntimeError("Mooncake memory registration failed.")

        # Pending layer transfers; producers enqueue from the forward pass.
        self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor,
                                            torch.Tensor]]()
        self.ready_event = ready_event
        # Invoked with (req_id, req_meta) once the last layer was sent.
        self.callback_func = callback_func

    def run(self):
        # Bind this thread to the worker's NPU before any device access.
        local_rank = get_world_group().local_rank
        device = torch.device(f"npu:{local_rank}")
        torch.npu.set_device(device)
        self.ready_event.set()
        while True:
            req_id, req_meta, layer_index, key, value = self.send_queue.get()
            self._handle_request(req_id, req_meta, layer_index, key, value)

    def _handle_request(self, req_id, req_meta, layer_index, key, value):
        # Logs and swallows transfer errors so the loop in run() keeps
        # serving subsequent requests.
        try:
            logger.debug(
                f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
            )
            self._transfer_kv_cache(req_id, req_meta, layer_index, key, value)
            logger.debug(
                f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
            )
        except Exception as e:
            logger.error("Failed to transfer KV cache for request "
                         f"{req_id}: {e}")

    def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value):
        """Write one layer's KV blocks for *req_id* to the remote peer.

        Raises:
            RuntimeError: if the mooncake batch write reports failure.
        """
        # send kv layer to remote
        if len(req_meta.local_block_ids) == 0:
            return
        # not need to send kv cache: only the first rank of each
        # head-replica group transfers; the others hold duplicate heads.
        if self.tp_rank % self.num_head_replica != 0:
            return

        remote_host = req_meta.remote_host
        remote_block_ids = req_meta.remote_block_ids
        remote_te_port = req_meta.remote_te_rpc_port
        remote_kv_base_addrs = req_meta.remote_kv_caches_base_addr
        local_kv_base_addr = self.kv_caches_base_addr
        local_block_ids = req_meta.local_block_ids

        if self.pd_head_ratio == 1:
            # Equal head layout: write blocks straight from the registered
            # KV cache. Base addresses are stored pairwise per layer, hence
            # indices 2*layer_index and 2*layer_index + 1.
            layer_local_kv_base_addr = [
                local_kv_base_addr[i]
                for i in [2 * layer_index, 2 * layer_index + 1]
            ]
            layer_remote_kv_base_addr = [
                remote_kv_base_addrs[i]
                for i in [2 * layer_index, 2 * layer_index + 1]
            ]
            # Coalesce runs of contiguous block ids into single transfers.
            grouped_remote_block_ids, grouped_local_block_ids = \
                group_concurrent_contiguous(remote_block_ids, local_block_ids)
            session_id = f"{remote_host}:{remote_te_port}"
            src_list, dst_list, length_list = [], [], []
            for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
                    zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
                # For MLA the two halves of a layer have different lengths.
                block_len = self.block_len[
                    k % 2] if self.use_mla else self.block_len[0]
                for group_remote_block_id, group_local_block_id in zip(
                        grouped_remote_block_ids, grouped_local_block_ids):
                    src = src_layer_base_addr + group_local_block_id[
                        0] * block_len
                    dst = dst_layer_base_addr + group_remote_block_id[
                        0] * block_len
                    length = len(group_local_block_id) * block_len
                    src_list.append(src)
                    dst_list.append(dst)
                    length_list.append(length)
            # Ensure the forward pass finished writing this layer's KV
            # before the DMA reads it.
            torch.npu.synchronize()
            ret = self.engine.batch_transfer_sync_write(
                session_id, src_list, dst_list, length_list)
            if ret < 0:
                logger.error("Mooncake transfer failed for request %s", req_id)
                raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
        else:
            # Unequal head layout: stage this rank's K/V slice in the
            # pre-registered buffers, then scatter it into the remote blocks
            # at this rank's head offset.
            key = key.view(-1, key.shape[-1])
            # NOTE(review): value is reshaped with key.shape[-1]; K and V
            # presumably share the same last dim here -- confirm.
            value = value.view(-1, key.shape[-1])
            self.k_buffer[:key.shape[0]].copy_(key)  # [:4, 128] ->
            self.v_buffer[:value.shape[0]].copy_(value)
            layer_local_kv_base_addr = [
                self.k_buffer.data_ptr(),
                self.v_buffer.data_ptr()
            ]
            layer_remote_kv_base_addr = [
                remote_kv_base_addrs[i]
                for i in [2 * layer_index, 2 * layer_index + 1]
            ]
            grouped_remote_block_ids, _ = group_concurrent_contiguous(
                remote_block_ids)
            session_id = f"{remote_host}:{remote_te_port}"
            src_list, dst_list, length_list = [], [], []
            for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
                    zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
                src_layer_addr = src_layer_base_addr
                for group_remote_block_id in grouped_remote_block_ids:
                    block_len = self.block_len[0]
                    # A remote block holds pd_head_ratio ranks' worth of heads.
                    remote_block_len = self.block_len[0] * self.pd_head_ratio
                    src_list.append(src_layer_addr)
                    # Clamp the final chunk so the read never runs past the
                    # staged K/V data.
                    if src_layer_addr + len(
                            group_remote_block_id
                    ) * block_len > src_layer_base_addr + key.numel(
                    ) * key.element_size():
                        length = src_layer_base_addr + key.numel(
                        ) * key.element_size() - src_layer_addr
                    else:
                        length = len(group_remote_block_id) * block_len
                    length_list.append(length)
                    # Destination offset selects this rank's head slice
                    # within the remote block.
                    dst_list.append(dst_layer_base_addr +
                                    group_remote_block_id[0] *
                                    remote_block_len + length *
                                    ((self.tp_rank // self.num_head_replica) %
                                     self.pd_head_ratio))
                    src_layer_addr += length
            torch.npu.synchronize()
            ret = self.engine.batch_transfer_sync_write(
                session_id, src_list, dst_list, length_list)
            if ret < 0:
                logger.error("Mooncake transfer failed for request %s", req_id)
                raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")

        # Last layer sent -> notify the owner (e.g. to emit DONE_SENDING).
        if layer_index == (self.total_layers - 1):
            self.callback_func(req_id, req_meta)
2025-09-30 15:10:29 +08:00
class KVCacheRecvingLayerThread(threading.Thread):
    """Side-channel listener on the receiving (decode) side.

    Serves two message types over a ZMQ ROUTER socket:

    * ``GET_META_MSG``: replies with this worker's encoded
      ``MooncakeAgentMetadata`` so the sender can address local KV memory.
    * ``DONE_SENDING_MSG``: counts per-request completion notifications from
      the prefill side; once ``pd_head_ratio`` senders have reported, the
      request id is published via ``get_and_clear_finished_requests``.
    """

    def __init__(self, tp_rank: int, side_channel_port: int, tp_size: int,
                 pd_head_ratio: int, local_engine_id: str,
                 metadata: "MooncakeAgentMetadata",
                 ready_event: threading.Event):
        super().__init__(daemon=True, name="KVCacheRecvingLayerThread")
        self.tp_rank = tp_rank
        self.tp_size = tp_size
        self.pd_head_ratio = pd_head_ratio
        self.local_engine_id = local_engine_id
        self.side_channel_host = get_ip()
        self.side_channel_port = side_channel_port
        # Protects done_requests and task_tracker: updated from this thread,
        # drained by the caller of get_and_clear_finished_requests().
        self.lock = threading.Lock()
        self.done_requests = set[str]()
        # req_id -> number of DONE_SENDING notifications received so far.
        self.task_tracker = dict[str, int]()
        self.ready_event = ready_event
        self.metadata = metadata

    def get_and_clear_finished_requests(self) -> set[str]:
        """
        Get and clear the requests that have been completed.
        Returns:
            A set of request IDs that have been completed.
        """
        with self.lock:
            finished_requests = self.done_requests
            self.done_requests = set()
        return finished_requests

    def update_task(self, req_id):
        """Record one DONE_SENDING notification for *req_id*; mark the
        request finished after pd_head_ratio notifications."""
        with self.lock:
            # Fix: start from 0 for a previously unseen req_id instead of
            # indexing blindly, which raised KeyError on the first
            # notification (and was silently swallowed by run()).
            count = self.task_tracker.get(req_id, 0) + 1
            if count == self.pd_head_ratio:
                self.task_tracker.pop(req_id, None)
                self.done_requests.add(req_id)
            else:
                self.task_tracker[req_id] = count

    def run(self):
        """Run the thread to handle KV cache transfer requests."""
        handshake_port = self.side_channel_port + self.tp_rank
        path = make_zmq_path("tcp", self.side_channel_host, handshake_port)
        logger.info("Starting listening on path: %s", path)

        # The agent metadata never changes; encode the reply once up front.
        encoder = msgspec.msgpack.Encoder()
        encoded_data = encoder.encode(self.metadata)
        with zmq_ctx(zmq.ROUTER, path) as sock:  # type: ignore
            self.ready_event.set()
            decoder = msgspec.msgpack.Decoder(type=tuple)
            while True:
                try:
                    frames = sock.recv_multipart()
                    if len(frames) < 2:
                        logger.error("Invalid message format: %s", frames)
                        continue
                    identity = frames[0]
                    # ROUTER delivers [identity, empty delimiter, payload].
                    payload = [f for f in frames[1:] if f != b""]
                    if len(payload) != 1:
                        logger.error("Invalid message format: %s", frames)
                        continue
                    msg = decoder.decode(payload[0])
                    if msg[0] == GET_META_MSG:
                        # Fix: previously logged the message tag as if it
                        # were a request id.
                        logger.info(
                            "Got GET_META_MSG, replying with agent metadata")
                        sock.send_multipart((identity, b"", encoded_data))
                    elif msg[0] == DONE_SENDING_MSG:
                        # Fix: log message matched the wrong tag
                        # (DONE_RECVING_MSG).
                        logger.debug("Got DONE_SENDING_MSG for request %s",
                                     msg[1])
                        request_id = msg[1]
                        self.update_task(request_id)
                        sock.send_multipart((identity, b"", b"ACK"))
                    else:
                        logger.error(
                            "Connection listener got unexpected message %s",
                            msg)
                except Exception as e:
                    logger.error("Failed to decode message: %s", e)
class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
    """Per-step connector metadata: the requests workers must act on."""

    def __init__(self):
        # Keyed by request id.
        self.requests: dict[str, ReqMeta] = {}

    def add_new_req(self,
                    request_id: str,
                    local_block_ids: list[int],
                    kv_transfer_params: dict[str, Any],
                    token_ids: Optional[list[int]] = None):
        """Register *request_id* along with its remote-transfer parameters
        extracted from *kv_transfer_params*."""
        param = kv_transfer_params.get
        meta = ReqMeta(
            local_block_ids=local_block_ids,
            token_ids=[] if not token_ids else token_ids,
            remote_block_ids=param("remote_block_ids", []),
            remote_engine_id=param("remote_engine_id", None),
            remote_host=param("remote_host", None),
            remote_port=param("remote_port", None),
            remote_te_rpc_port=param("remote_te_rpc_port", None),
            remote_kv_caches_base_addr=param("remote_kv_caches_base_addr",
                                             None),
            metaserver=param("metaserver", None),
        )
        self.requests[request_id] = meta
class MooncakeLayerwiseConnector(KVConnectorBase_V1):
    """Facade that routes KV-connector calls to the scheduler-side or
    worker-side implementation, depending on the connector role."""

    def __init__(self, vllm_config: VllmConfig, role: KVConnectorRole):
        assert vllm_config.kv_transfer_config is not None
        self.engine_id = vllm_config.kv_transfer_config.engine_id
        self._connector_metadata = MooncakeLayerwiseConnectorMetadata()

        # Exactly one of connector_scheduler / connector_worker is set,
        # matching the process this connector lives in.
        if role == KVConnectorRole.SCHEDULER:
            self.connector_scheduler: Optional[MooncakeLayerwiseConnectorScheduler] = \
                MooncakeLayerwiseConnectorScheduler(vllm_config, str(self.engine_id))
            self.connector_worker: Optional[
                MooncakeLayerwiseConnectorWorker] = None
        elif role == KVConnectorRole.WORKER:
            self.connector_scheduler = None
            self.connector_worker = MooncakeLayerwiseConnectorWorker(
                vllm_config, str(self.engine_id))

    ############################################################
    # Scheduler Side Methods
    ############################################################

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """Delegate to the scheduler-side implementation."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.get_num_new_matched_tokens(
            request, num_computed_tokens)

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Delegate to the scheduler-side implementation."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.update_state_after_alloc(
            request, blocks, num_external_tokens)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Delegate to the scheduler-side implementation."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.build_connector_meta(scheduler_output)

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """Delegate to the scheduler-side implementation."""
        assert self.connector_scheduler is not None
        return self.connector_scheduler.request_finished(request, block_ids)

    ############################################################
    # Worker Side Methods
    ############################################################

    def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
        """Delegate KV-cache registration to the worker side."""
        assert self.connector_worker is not None
        self.connector_worker.register_kv_caches(kv_caches)

    def get_finished(self,
                     finished_req_ids: set[str]) -> tuple[set[str], set[str]]:
        """Get the finished recving and sending requests."""
        assert self.connector_worker is not None
        return self.connector_worker.get_finished()

    def start_load_kv(self, forward_context: "ForwardContext",
                      **kwargs) -> None:
        """Delegate KV loading to the worker side."""
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata,
                          MooncakeLayerwiseConnectorMetadata)
        self.connector_worker.start_load_kv(self._connector_metadata)

    def wait_for_layer_load(self, layer_name: str) -> None:
        """MooncakeLayerwiseConnector does not do layerwise saving."""
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata,
                          MooncakeLayerwiseConnectorMetadata)
        self.connector_worker.wait_for_layer_load(layer_name)

    def save_kv_layer(self, layer_name: str, kv_layer: torch.Tensor,
                      attn_metadata: "AttentionMetadata", **kwargs) -> None:
        """MooncakeLayerwiseConnector does not save explicitly."""
        assert self.connector_worker is not None
        assert isinstance(self._connector_metadata,
                          MooncakeLayerwiseConnectorMetadata)
        self.connector_worker.save_kv_layer(layer_name, kv_layer,
                                            attn_metadata,
                                            self._connector_metadata)

    def wait_for_save(self):
        """MooncakeLayerwiseConnector does not save explicitly."""
        pass
class MooncakeLayerwiseConnectorScheduler:
    """Implementation of Scheduler side methods"""

    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.engine_id = engine_id
        logger.info("Initializing Mooncake Scheduler %s", engine_id)
        self.side_channel_host = get_ip()
        self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \
            vllm_config.parallel_config.data_parallel_size
        # Handshake base port
        # Each DP rank gets its own contiguous range of tp_size ports.
        self.side_channel_port = (
            vllm_config.kv_transfer_config.kv_port +
            vllm_config.parallel_config.data_parallel_rank *
            vllm_config.parallel_config.tensor_parallel_size)

        # Requests that need to start recv.
        # New requests are added by update_state_after_alloc in
        # the scheduler. Used to make metadata passed to Worker.
        self._reqs_need_recv: dict[str, tuple[Request, list[int],
                                              list[int]]] = {}
        self._reqs_need_send_layerwise: dict[str, tuple[
            int, list[int],
            Request]] = {}  # req_id, (len(prompt), local_block_ids, request)

    def get_num_new_matched_tokens(
            self, request: "Request",
            num_computed_tokens: int) -> tuple[int, bool]:
        """
        For remote prefill, pull all prompt blocks from remote
        asynchronously relative to engine execution.

        Args:
            request (Request): the request object.
            num_computed_tokens (int): the number of locally
                computed tokens for this request
        Returns:
            * the number of tokens that can be loaded from the
              external KV cache beyond what is already computed.
            * true if the external KV cache tokens will be loaded
              asynchronously (between scheduler steps).
        """
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeLayerwiseConnector get_num_new_matched_tokens: "
            "num_computed_tokens=%s, kv_transfer_params=%s",
            num_computed_tokens, params)
        if params is not None and params.get("do_remote_prefill"):
            # Remote prefill: get all prompt blocks from remote.
            assert num_computed_tokens % self.block_size == 0
            # Note: We use the full token count as transmit data here.
            count = max(len(request.prompt_token_ids) - num_computed_tokens, 0)
            return count, count > 0

        # No remote prefill for this request.
        return 0, False

    def update_state_after_alloc(self, request: "Request",
                                 blocks: "KVCacheBlocks",
                                 num_external_tokens: int):
        """Record block allocations so the worker can start transfers."""
        params = request.kv_transfer_params
        logger.debug(
            "MooncakeLayerwiseConnector update_state_after_alloc: "
            "num_external_tokens=%s, kv_transfer_params=%s",
            num_external_tokens, params)
        if params is not None and params.get("do_remote_prefill"):
            local_block_ids = (blocks.get_unhashed_block_ids()
                               if num_external_tokens > 0 else [])
            # Get unhashed blocks to pull from remote.
            self._reqs_need_recv[request.request_id] = (
                request,
                [],  #request._all_token_ids,
                local_block_ids)
            # Consume the flag so the request is only scheduled once.
            params["do_remote_prefill"] = False

        # Layerwise prefiller add request need send
        if params is not None and params.get("do_remote_decode"):
            local_block_ids = (blocks.get_block_ids()[0])
            self._reqs_need_send_layerwise[request.request_id] = (len(
                request.all_token_ids), local_block_ids, request)

    def build_connector_meta(
        self,
        scheduler_output: SchedulerOutput,
    ) -> KVConnectorMetadata:
        """Build the per-step metadata consumed by the worker side."""
        meta = MooncakeLayerwiseConnectorMetadata()

        # Loop through scheduled reqs and convert to ReqMeta.
        for req_id, (req, token_ids,
                     block_ids) in self._reqs_need_recv.items():
            assert req.kv_transfer_params is not None
            # For the case where there are no remote blocks to pull
            # (block_ids is empty), we don't need to schedule
            # an async read on the worker side.
            meta.add_new_req(request_id=req_id,
                             local_block_ids=block_ids,
                             kv_transfer_params=req.kv_transfer_params,
                             token_ids=token_ids)

        # Clear the list once workers start the transfers
        self._reqs_need_recv.clear()

        cached_reqs = scheduler_output.scheduled_cached_reqs
        new_reqs = scheduler_output.scheduled_new_reqs
        # Accumulate newly allocated blocks for requests still being sent.
        for req_id, new_blocks in zip(cached_reqs.req_ids,
                                      cached_reqs.new_block_ids):
            if req_id in self._reqs_need_send_layerwise and new_blocks is not None:
                total_tokens, block_ids, req = self._reqs_need_send_layerwise[
                    req_id]
                # NOTE(review): new_blocks[0] assumes a single KV-cache
                # group -- confirm for hybrid-cache models.
                block_ids.extend(new_blocks[0])
        computed_tokens = dict(
            list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) +
            [(x.req_id, x.num_computed_tokens) for x in new_reqs])
        for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(
        ):
            if req_id in self._reqs_need_send_layerwise:
                total_tokens, block_ids, req = self._reqs_need_send_layerwise[
                    req_id]
                current_tokens = computed_tokens.get(req_id,
                                                     0) + scheduled_tokens
                # Once this step covers the whole prompt, every layer will be
                # produced this forward pass -> schedule the send now.
                if current_tokens == total_tokens:
                    meta.add_new_req(request_id=req_id,
                                     local_block_ids=block_ids,
                                     kv_transfer_params=req.kv_transfer_params,
                                     token_ids=[])
                    self._reqs_need_send_layerwise.pop(req_id)
        return meta

    def request_finished(
        self,
        request: "Request",
        block_ids: list[int],
    ) -> tuple[bool, Optional[dict[str, Any]]]:
        """
        Once a request is finished, determine whether request blocks
        should be freed now or will be sent asynchronously and freed later.
        """
        # layer_wise push, not need delay_free_blocks
        return False, None
2025-09-30 15:10:29 +08:00
class MooncakeLayerwiseConnectorWorker:
    """Implementation of Worker side methods"""

    def __init__(self, vllm_config: VllmConfig, engine_id: str):
        self._get_prefill_decode_size(vllm_config)
        # A decode rank may aggregate heads from several prefill ranks, but
        # never the reverse.
        if self._prefill_tp_size < self._decode_tp_size:
            raise ValueError(
                f"prefill_tp_size: {self._prefill_tp_size} must be greater than "
                f"or equal to the decode_tp_size: {self._decode_tp_size}")
        if TransferEngine is None:
            raise RuntimeError("mooncake is not available")
        logger.info("Initializing Mooncake work %s", engine_id)
        self.engine = TransferEngine()

        # Metadata.
        self.vllm_config = vllm_config
        self.local_engine_id: str = ""
        self.engine_id = engine_id
        self.tp_rank = get_tensor_model_parallel_rank()
        self.tp_size = vllm_config.parallel_config.tensor_parallel_size
        self.tp_group = get_tp_group()
        self.dp_rank = vllm_config.parallel_config.data_parallel_rank
        self.dp_size = vllm_config.parallel_config.data_parallel_size_local
        self.kv_caches: dict[str, torch.Tensor] = {}
        self.side_channel_host = get_ip()
        self.max_device_id = self.tp_size * self.dp_size
        self.total_layers = vllm_config.model_config.get_num_layers(
            vllm_config.parallel_config)
        self.executor = ThreadPoolExecutor(32)
        # Only TP rank 0 talks to the metaserver.
        self.metaserver_client = httpx.Client(
            limits=httpx.Limits(max_connections=100000),
            timeout=None) if self.tp_rank == 0 else None

        # Handshake base port: each DP rank owns a range of tp_size ports.
        self.side_channel_port = (
            vllm_config.kv_transfer_config.kv_port +
            vllm_config.parallel_config.data_parallel_rank *
            vllm_config.parallel_config.tensor_parallel_size)
        self.handshake_port = self.side_channel_port + self.tp_rank
        self.sockets: dict = {}

        # get tp device id
        # TODO(kw): https://github.com/vllm-project/vllm-ascend/pull/940
        # introducing some changes
        device_ids_str = envs_ascend.PHYSICAL_DEVICES
        if device_ids_str is None:
            device_ids = list(
                range(self.dp_rank * self.tp_size,
                      (self.dp_rank + 1) * self.tp_size))
        else:
            device_ids = list(map(int, device_ids_str.split(',')))
            start_index = self.dp_rank * self.tp_size
            end_index = start_index + self.tp_size
            if len(device_ids) < end_index:
                raise ValueError(
                    f"Not enough physical devices available for DP rank {self.dp_rank}. "
                    f"Expected at least {end_index} devices, but found {len(device_ids)} "
                    "in PHYSICAL_DEVICES.")
            device_ids = device_ids[start_index:end_index]
        assert len(device_ids) > self.tp_rank  # type: ignore
        self.device_id = device_ids[self.tp_rank]  # type: ignore

        if vllm_config.kv_transfer_config.get_from_extra_config(
                'use_ascend_direct', False):
            hostname = self.side_channel_host
        else:
            hostname = f"{self.side_channel_host}:0:npu_{self.device_id}"
        self._initialize(hostname=hostname, device_name=None)
        self.te_rpc_port = self.engine.get_rpc_port()

        # Background thread for sending or receiving KV caches.
        self.kv_recv_layer_thread: Optional[KVCacheRecvingLayerThread] = None
        self.kv_send_layer_thread: Optional[KVCacheSendingLayerThread] = None

        # NOTE(review): vllm_config was already stored above; duplicate
        # assignment kept as-is.
        self.vllm_config = vllm_config
        self.block_size = vllm_config.cache_config.block_size
        self.kv_caches_base_addr: list[int] = []
        self.pd_tp_ratio = get_ascend_config().pd_tp_ratio
        self.pd_head_ratio = get_ascend_config().pd_head_ratio
        self.num_head_replica = get_ascend_config().num_head_replica

        self.first_kv_cache = None
        self.remote_poller = zmq.Poller()  # type: ignore
        self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
        self.encoder = msgspec.msgpack.Encoder()
        # Remote handshake results, keyed by host then handshake port.
        self.remote_kv_caches_base_addr: dict[str, dict[int, list[int]]] = \
            defaultdict(dict)
        self.remote_te_port: dict[str, dict[int, int]] = \
            defaultdict(dict)
        # Pool of reusable side-channel sockets per remote path.
        self.remote_sockets_lock = threading.Lock()
        self.remote_sockets: dict[  # type: ignore
            str, deque[zmq.Socket]] = defaultdict(  # type: ignore
                deque)
        # NOTE(review): remote_poller is re-created here, discarding the one
        # built a few lines above -- kept as-is.
        self.remote_poller = zmq.Poller()  # type: ignore
        self.timeout = 1.0  # seconds
2025-09-30 15:10:29 +08:00
def _get_prefill_decode_size ( self , vllm_config : VllmConfig ) :
# get prefill tp and dp size from extra config
prefill_parallel_config : dict [
str , Any ] = vllm_config . kv_transfer_config . get_from_extra_config (
" prefill " , { } )
assert " tp_size " in prefill_parallel_config . keys ( )
self . _prefill_tp_size = prefill_parallel_config [ " tp_size " ]
assert " dp_size " in prefill_parallel_config . keys ( )
self . _prefill_dp_size = prefill_parallel_config [ " dp_size " ]
# get decode tp and dp size from extra config
decode_parallel_config : dict [
str , Any ] = vllm_config . kv_transfer_config . get_from_extra_config (
" decode " , { } )
assert " tp_size " in decode_parallel_config . keys ( )
self . _decode_tp_size = decode_parallel_config [ " tp_size " ]
assert " dp_size " in decode_parallel_config . keys ( )
self . _decode_dp_size = decode_parallel_config [ " dp_size " ]
def _initialize (
self ,
hostname : str ,
device_name : Optional [ str ] ,
) - > None :
""" Initialize the mooncake instance. """
device_name = device_name if device_name is not None else " "
ret_value = self . engine . initialize ( hostname , " P2PHANDSHAKE " , " ascend " ,
device_name )
if ret_value != 0 :
raise RuntimeError (
f " Mooncake initialization failed with ret_value: { ret_value } " )
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
    """Register the KV Cache data.

    Derives the cache geometry (num_blocks / per-block byte lengths),
    registers every cache tensor's memory region with the mooncake
    transfer engine, then starts the sending thread (producer role)
    and/or receiving thread (consumer role), blocking until each thread
    signals readiness.
    """
    _, first_kv_cache_tuple = next(iter(kv_caches.items()))
    first_kv_cache = first_kv_cache_tuple[0]
    self.first_kv_cache = first_kv_cache
    # TODO(tms): Find a more robust way to detect and handle MLA
    # MLA is inferred from the two cache tensors having different last
    # dims; regular attention has identically shaped k/v caches.
    self.use_mla = first_kv_cache_tuple[0].size(
        -1) != first_kv_cache_tuple[1].size(-1)
    if self.use_mla:
        # MLA case.[num_block, block_size, 1, hidden_dim]
        self.num_blocks = first_kv_cache.shape[0]
        block_rank = 3  # [block_size, latent_dim]
        block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
        block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
        # Two block lengths: even cache indices use the "norm" length,
        # odd ones the "pe" length (see the i % 2 indexing below).
        self.block_len = [
            first_kv_cache[0].element_size() * math.prod(block_shape_norm),
            first_kv_cache[1].element_size() * math.prod(block_shape_pe)
        ]
        logger.info(
            "num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
            self.num_blocks, block_shape_norm, block_shape_pe)
    else:
        # [num_block, block_size, num_head, hidden_dim]
        self.num_blocks = first_kv_cache.shape[0]
        kv_elem_size = first_kv_cache.element_size()
        block_rank = 3  # [block_size, kv_heads, head_dim]
        block_shape = first_kv_cache.shape[-block_rank:]
        self.block_len = [kv_elem_size * math.prod(block_shape)]
        logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
                    block_shape)
    logger.info("Registering KV_Caches. use_mla: %s, shape %s",
                self.use_mla, first_kv_cache.shape)
    self.kv_caches = kv_caches
    kv_caches_base_addr = []
    for cache_or_caches in kv_caches.values():
        # Normalize to always be a list of caches
        if self.use_mla:
            for i, cache in enumerate(cache_or_caches, 0):
                base_addr = cache.data_ptr()
                # Alternate norm/pe block lengths per cache tensor.
                region_len = self.num_blocks * self.block_len[i % 2]
                kv_caches_base_addr.append(base_addr)
                self._register(base_addr, region_len)
        else:
            # NOTE(review): the conditional is always False on this
            # branch, so cache_list is simply cache_or_caches.
            cache_list = [cache_or_caches
                          ] if self.use_mla else cache_or_caches
            for cache in cache_list:
                base_addr = cache.data_ptr()
                region_len = self.num_blocks * self.block_len[0]
                kv_caches_base_addr.append(base_addr)
                self._register(base_addr, region_len)
    self.kv_caches_base_addr = kv_caches_base_addr
    # After KV Caches registered, start the sending or receiving thread.
    metadata = MooncakeAgentMetadata(
        te_rpc_port=self.te_rpc_port,
        kv_caches_base_addr=self.kv_caches_base_addr,
    )
    if self.vllm_config.kv_transfer_config.is_kv_producer:
        ready_event = threading.Event()
        self.kv_send_layer_thread = KVCacheSendingLayerThread(
            engine=self.engine,
            total_layers=self.total_layers,
            ready_event=ready_event,
            tp_rank=self.tp_rank,
            pd_head_ratio=self.pd_head_ratio,
            num_head_replica=self.num_head_replica,
            kv_cache_base_addr=self.kv_caches_base_addr,
            use_mla=self.use_mla,
            block_len=self.block_len,
            first_kv_cache=first_kv_cache,
            callback_func=self.send_done_send_signal)
        self.kv_send_layer_thread.start()
        # Block until the sender thread is ready to accept work.
        ready_event.wait()
    if self.vllm_config.kv_transfer_config.is_kv_consumer:
        ready_event = threading.Event()
        self.kv_recv_layer_thread = KVCacheRecvingLayerThread(
            self.tp_rank, self.side_channel_port, self.tp_size,
            self.pd_head_ratio, self.engine_id, metadata, ready_event)
        self.kv_recv_layer_thread.start()
        # Block until the receiver thread is listening.
        ready_event.wait()
2025-09-30 15:10:29 +08:00
def _register(self, ptr, length):
    """Register one contiguous KV-cache memory region with the mooncake
    transfer engine; raises RuntimeError on a nonzero status."""
    logger.info(
        "Registering KV cache: ptr=0x%x, length=%d, num_blocks=%d, "
        "block_lens=%s", ptr, length, self.num_blocks, self.block_len)
    if self.engine.register_memory(ptr, length) != 0:
        raise RuntimeError("Mooncake memory registration failed.")
def _access_metaserver ( self , url , message ) :
2025-10-30 22:21:11 +08:00
success = False
retry = 0
while retry < 3 and success is False :
retry + = 1
try :
self . metaserver_client . post ( url , json = message )
success = True
except Exception as e :
logger . error (
f " Failed to connect to metaserver: { url } , retry { retry } time. "
)
if retry == 3 :
raise e
2025-09-30 15:10:29 +08:00
def get_finished(self) -> tuple[set[str], set[str]]:
    """Return (done_sending, done_recving) request-id sets.

    Only the consumer side tracks finished receives here; done_sending
    is always empty (send completion is signalled over the side
    channel instead).
    """
    if self.vllm_config.kv_transfer_config.is_kv_consumer:
        done_recving = (
            self.kv_recv_layer_thread.
            get_and_clear_finished_requests()  # type: ignore[union-attr]
        )
    else:
        done_recving = set()
    if done_recving:
        # BUGFIX: the previous message logged a hard-coded 0 as the
        # completed count; log the real number once.
        logger.info("Number of completed KV cache recv requests: %d",
                    len(done_recving))
    return set(), done_recving
2025-09-30 15:10:29 +08:00
def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
    """Start loading KV blocks from remote engine.

    Consumer side only: announces each request to the proxy metaserver
    (so the producer learns where to push) and registers the request in
    the receiving thread's task tracker with zero layers received.
    """
    self.current_layer = 0
    if self.vllm_config.kv_transfer_config.is_kv_consumer:
        for req_id, meta in metadata.requests.items():
            # NOTE(review): with tp_rank < tp_size this selects only
            # rank 0 — presumably one rank per tp group posts to the
            # metaserver; confirm the intent of the modulo.
            if self.tp_rank % self.tp_size == 0:
                logger.info(
                    f"Send request: {req_id} to proxy metaserver: {meta.metaserver}"
                )
                # All parameters here should appear in the returned dict of
                # request_finished in the scheduler side except "request_id".
                kv_transfer_params = dict(
                    token_ids=meta.token_ids,
                    request_id=req_id,
                    do_remote_prefill=False,
                    do_remote_decode=True,
                    remote_block_ids=meta.local_block_ids,
                    remote_engine_id=self.engine_id,
                    remote_host=self.side_channel_host,
                    remote_port=self.side_channel_port,
                )
                # Fire-and-forget: failures are only logged by the
                # done-callback and do not block scheduling.
                future = self.executor.submit(
                    self._access_metaserver,
                    url=meta.metaserver,
                    message=kv_transfer_params,
                )

                def handle_exception(future):
                    if future.exception():
                        logger.error(
                            f"Access metaserver fail: {future.exception()}")

                future.add_done_callback(handle_exception)
            assert self.kv_recv_layer_thread is not None
            with self.kv_recv_layer_thread.lock:
                # Start tracking: 0 layers received so far.
                self.kv_recv_layer_thread.task_tracker[req_id] = 0
2025-09-30 15:10:29 +08:00
def save_kv_layer(self, layer_name: str, kv_layer: Tuple[torch.Tensor,
                                                         torch.Tensor],
                  attn_metadata: "AttentionMetadata",
                  connector_metadata: MooncakeLayerwiseConnectorMetadata,
                  **kwargs) -> None:
    """MooncakeLayerwiseConnector does not save explicitly.

    Producer side: trims each request's local block ids to match the
    remote side (prefix-cache aware), optionally rearranges heads via
    alltoall when prefill/decode head counts differ, and enqueues the
    layer's k/v slices on the sending thread, one queue item per
    request per layer.
    """
    if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(
    ):
        # enable decode prefix cache
        for request in connector_metadata.requests.values():
            assert len(request.local_block_ids) >= len(
                request.remote_block_ids
            ), "When prefix cache enabled, remote KVCacheBlocks num should not larger than local KVCacheBlocks num."
            # Keep only the tail blocks that actually need transfer.
            request.local_block_ids = request.local_block_ids[
                -len(request.remote_block_ids):]
        if self.pd_head_ratio != 1:

            def sort_kv_cache(input_kv: list[torch.Tensor]):
                # Interleave head-chunks across requests so the
                # alltoall receives contiguous per-rank slices.
                return torch.cat([
                    torch.chunk(tensor, self.pd_head_ratio, dim=0)[x]
                    for x in range(self.pd_head_ratio)
                    for tensor in input_kv
                ])

            total_block_ids = [
                request.local_block_ids
                for request in connector_metadata.requests.values()
            ]
            # Gather each request's blocks as flat [tokens, ...] slabs.
            keys = [
                kv_layer[0][block_ids].reshape(
                    -1, *kv_layer[0].shape[2:]).clone()
                for block_ids in total_block_ids
            ]
            values = [
                kv_layer[1][block_ids].reshape(
                    -1, *kv_layer[1].shape[2:]).clone()
                for block_ids in total_block_ids
            ]
            # Rows contributed per block (used to slice per request).
            key_block_size = keys[0].size(0) // len(total_block_ids[0])
            value_block_size = values[0].size(0) // len(total_block_ids[0])
            keys = sort_kv_cache(keys)  # [req1_key, req2_key]
            values = sort_kv_cache(values)
            (keys,
             values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys,
                                                 values)
            key_start_id = 0
            value_start_id = 0
        else:
            # No head rearrangement: the sender reads straight from the
            # registered cache, so no tensors are attached.
            key = None
            value = None
        for req_id, req_meta in connector_metadata.requests.items():
            logger.debug(
                f"Add request {req_id} to kv send layer thread. {req_meta=}"
            )
            if self.pd_head_ratio != 1:
                # Slice this request's rows out of the rearranged slabs.
                key_block_num = len(
                    req_meta.local_block_ids) * key_block_size
                value_block_num = len(
                    req_meta.local_block_ids) * value_block_size
                key = keys[key_start_id:key_start_id + key_block_num]
                value = values[value_start_id:value_start_id +
                               value_block_num]
                key_start_id += key_block_num
                value_start_id += value_block_num
            req_meta_update = self.update_decoder_info(req_id, req_meta)
            assert self.kv_send_layer_thread is not None
            self.kv_send_layer_thread.send_queue.put(
                (req_id, req_meta_update, self.current_layer, key, value))
        self.current_layer += 1
2025-10-30 22:21:11 +08:00
def _get_remote_socket(
        self, remote_host: str,
        remote_handshake_port: int) -> zmq.Socket:  # type: ignore
    """Get a socket to the remote host.

    Reuses a pooled socket for this path when available; otherwise
    builds a new REQ socket with a send timeout and registers it with
    the shared poller.
    """
    remote_path = make_zmq_path("tcp", remote_host, remote_handshake_port)
    with self.remote_sockets_lock:
        pooled = self.remote_sockets[remote_path]
        if pooled:
            return pooled.popleft()
    # NOTE(review): each new socket gets its own Context that is never
    # terminated, and callers never return sockets to the pool here —
    # confirm lifetime management upstream.
    context = zmq.Context()  # type: ignore
    new_sock = make_zmq_socket(
        ctx=context,
        path=remote_path,
        socket_type=zmq.REQ,  # type: ignore
        bind=False)
    new_sock.setsockopt(
        zmq.SNDTIMEO,  # type: ignore
        int(self.timeout * 1000))
    self.remote_poller.register(new_sock, zmq.POLLIN)  # type: ignore
    return new_sock
def update_decoder_info(self, req_id, req_meta):
    """Return a deep copy of ``req_meta`` with remote transfer info resolved.

    Maps this tp rank onto the matching remote handshake port, performs
    a one-time metadata handshake with that remote endpoint (results
    are cached per (engine_id, port)), and stamps the remote TE rpc
    port and KV-cache base addresses onto the copied metadata.

    Raises:
        Exception: the original handshake error, if querying the remote
            endpoint fails.
    """
    req_meta_update = copy.deepcopy(req_meta)
    # Each local tp rank addresses a specific remote side-channel port.
    if self.pd_tp_ratio > 1:
        req_meta_update.remote_port = req_meta_update.remote_port + self.tp_rank // self.pd_tp_ratio
    else:
        req_meta_update.remote_port = req_meta_update.remote_port + self.tp_rank
    if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \
            req_meta_update.remote_port not in self.remote_kv_caches_base_addr[
                req_meta_update.remote_engine_id]:
        try:
            encoded_data = self.encoder.encode((GET_META_MSG, req_id))
            sock = self._get_remote_socket(req_meta_update.remote_host,
                                           req_meta_update.remote_port)
            ensure_zmq_send(sock, encoded_data)
            metadata_bytes = ensure_zmq_recv(sock, self.remote_poller)
            agent_meta = self.decoder.decode(metadata_bytes)
        except Exception:
            logger.exception(
                "Query to port and kv base addr for request %s from %s:%s fail",
                req_id, req_meta_update.remote_host,
                req_meta_update.remote_port)
            # BUGFIX: previously execution fell through after logging and
            # crashed below on the unbound `agent_meta`; propagate instead.
            raise
        assert req_meta_update.remote_engine_id != self.engine_id, (
            f"Conflict engine id {req_meta_update.remote_engine_id} with local engine id "
            # BUGFIX: message used self.local_engine_id, which does not exist.
            f"{self.engine_id}.")
        self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][
            req_meta_update.remote_port] = agent_meta.kv_caches_base_addr
        self.remote_te_port[req_meta_update.remote_engine_id][
            req_meta_update.remote_port] = agent_meta.te_rpc_port
        logger.info(
            f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
        )
    req_meta_update.remote_te_rpc_port = self.remote_te_port[
        req_meta_update.remote_engine_id][req_meta_update.remote_port]
    req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[
        req_meta_update.remote_engine_id][req_meta_update.remote_port]
    return req_meta_update
def send_done_send_signal(self, req_id, req_meta):
    """Tell the remote decoder that every layer of ``req_id`` was sent.

    Best-effort: any failure (including a bad ACK) is logged and
    swallowed rather than propagated.
    """
    logger.info("Sending done sending signal for request %s to %s:%d",
                req_id, req_meta.remote_host, req_meta.remote_port)
    try:
        path = make_zmq_path("tcp", req_meta.remote_host,
                             req_meta.remote_port)
        payload = msgspec.msgpack.Encoder().encode(
            (DONE_SENDING_MSG, req_id))
        with zmq_ctx(zmq.REQ, path) as sock:  # type: ignore
            ensure_zmq_send(sock, payload)
            # NOTE(review): this recv has no timeout — presumably the
            # peer always replies; confirm it cannot block forever.
            ack = sock.recv()
            if ack != b"ACK":
                raise ValueError(f"Unexpected ACK response: {ack}")
    except Exception as e:
        logger.error(
            f"Sending done sending signal for request {req_id} to {req_meta.remote_host}:{req_meta.remote_port} fail with error: {e}"
        )
2025-09-30 15:10:29 +08:00
def wait_for_layer_load(self, layer_name: str) -> None:
    """No-op: this connector performs no per-layer wait here."""
    return None
@contextlib.contextmanager
def zmq_ctx(socket_type: Any,
            addr: str) -> Iterator[zmq.Socket]:  # type: ignore
    """Context manager for a ZMQ socket"""
    if socket_type not in (zmq.ROUTER, zmq.REQ, zmq.DEALER):  # type: ignore
        raise ValueError(f"Unexpected socket type: {socket_type}")
    # If Context() itself raises there is nothing to clean up, so the
    # try/finally only needs to guard socket creation and use.
    context = zmq.Context()  # type: ignore
    try:
        yield make_zmq_socket(ctx=context,
                              path=addr,
                              socket_type=socket_type,
                              bind=socket_type == zmq.ROUTER)  # type: ignore
    finally:
        context.destroy(linger=0)
def group_concurrent_contiguous(
    src: List[int],
    dst: Optional[List[int]] = None
) -> Tuple[List[List[int]], List[List[int]]]:
    """Split id sequences into runs of consecutive integers.

    Vectorised NumPy implementation. When ``dst`` is provided it is
    assumed to be parallel to ``src`` (same length); a new group starts
    wherever either sequence breaks its +1 progression.

    Args:
        src: Source block ids.
        dst: Optional destination block ids, parallel to ``src``.
            BUGFIX: was a mutable default ``[]``; ``None`` behaves
            identically for callers.

    Returns:
        (src_groups, dst_groups) as plain lists of int lists;
        ``dst_groups`` is empty when no ``dst`` was given.
    """
    src_arr: npt.NDArray[np.int64] = np.array(src, dtype=np.int64)
    if src_arr.size == 0:
        return [], []
    if not dst:
        breaks = np.where(np.diff(src_arr) != 1)[0] + 1
        return [g.tolist() for g in np.split(src_arr, breaks)], []
    dst_arr: npt.NDArray[np.int64] = np.array(dst, dtype=np.int64)
    breaks = np.where((np.diff(src_arr) != 1)
                      | (np.diff(dst_arr) != 1))[0] + 1
    src_groups = [g.tolist() for g in np.split(src_arr, breaks)]
    dst_groups = [g.tolist() for g in np.split(dst_arr, breaks)]
    return src_groups, dst_groups
def string_to_int64_hash(input_str):
    """Hash ``input_str`` with SHA-256 and fold the first 8 digest
    bytes into an unsigned 64-bit integer (little-endian)."""
    digest = hashlib.sha256(input_str.encode("utf-8")).digest()
    # Equivalent to struct.unpack("<Q", digest[:8])[0].
    return int.from_bytes(digest[:8], byteorder="little")
def ensure_zmq_send (
socket : zmq . Socket , # type: ignore
data : bytes ,
max_retries : int = 3 ) :
retries_left = max_retries
while True :
try :
socket . send ( data )
return
except zmq . ZMQError as e : # type: ignore
retries_left - = 1
if retries_left > 0 :
logger . warning (
f " Send failed: { e } , retrying... ( { retries_left } "
" attempts left) " )
time . sleep ( 0.1 )
else :
logger . error ( f " Send failed after all retries: { e } " )
raise RuntimeError ( f " Failed to send data after { max_retries } "
f " retries: { e } " )
def ensure_zmq_recv (
socket : zmq . Socket , # type: ignore
poller : zmq . Poller , # type: ignore
timeout : float = 1.0 ,
max_retries : int = 3 ) - > bytes :
retries_left = max_retries
while True :
try :
if dict ( poller . poll ( int ( timeout * 1000 ) ) ) : # milliseconds
data = socket . recv ( )
return data
else :
raise zmq . ZMQError ( " Receive timeout " ) # type: ignore
except zmq . ZMQError as e : # type: ignore
retries_left - = 1
if retries_left > 0 :
logger . warning ( f " Receive failed: { e } , retrying... "
f " ( { retries_left } attempts left) " )
time . sleep ( 0.1 )
else :
logger . error ( f " Receive failed after all retries: { e } " )
raise RuntimeError (
f " Failed to receive data after { max_retries } "
f " retries: { e } " )