[P/D] Improve the performance of Layerwise Connector (#5303)

### What this PR does / why we need it?
Improve the performance of Layerwise Connector, mainly includes the
following points:
1. Use event synchronization instead of stream synchronization.
2. Access metaserver when scheduling.
3. Transfer the KV cache for each chunked-prefill segment.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
By CI.
- vLLM version: release/v0.13.0
- vLLM main:
5fbfa8d9ef

---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
zxr2333
2025-12-31 15:09:01 +08:00
committed by GitHub
parent 7d5242faca
commit 46a1614387
5 changed files with 354 additions and 202 deletions

View File

@@ -65,6 +65,32 @@ class ReqMeta:
remote_te_rpc_port: Optional[int]
remote_kv_caches_base_addr: Optional[list[int]]
metaserver: Optional[str]
chunk_finish: Optional[bool]
@dataclass
class SendReqInfo:
    """Per-request bookkeeping for layerwise KV-cache sending.

    Tracks block ids on both the local (prefill) and remote (decode) side,
    plus how many tokens have been computed locally and how many have
    already been transferred, so chunked-prefill segments can be sent
    incrementally.
    """
    local_block_ids: list[int]      # blocks allocated locally for this request
    remote_block_ids: list[int]     # blocks reported by the remote decode side
    remote_cache_tokens: int        # tokens already cached remotely (prefix-cache hit)
    local_transferred_tokens: int   # tokens whose KV cache has been sent so far
    local_computed_tokens: int      # tokens computed locally up to this step
    request: "Request"              # owning vLLM request (project type)

    def extend_local_block_ids(self, new_block_ids: list[int]) -> None:
        """Extend local block ids with the blocks newly allocated this step."""
        self.local_block_ids.extend(new_block_ids)

    def update_computed_tokens(self, computed_tokens: int) -> None:
        """Update the number of locally computed tokens for this step."""
        self.local_computed_tokens = computed_tokens

    def update_transferred_tokens(self, transferred_tokens: int) -> None:
        """Update the number of already-transferred tokens for this step."""
        self.local_transferred_tokens = transferred_tokens

    def unpack(self):
        """Return all fields as a tuple, in declaration order."""
        return (self.local_block_ids, self.remote_block_ids,
                self.remote_cache_tokens, self.local_transferred_tokens,
                self.local_computed_tokens, self.request)
@dataclass
@@ -144,7 +170,7 @@ class KVCacheSendingLayerThread(threading.Thread):
raise RuntimeError("Mooncake memory registration failed. ")
self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor,
torch.Tensor]]()
torch.Tensor, torch.npu.Event]]()
self.ready_event = ready_event
self.callback_func = callback_func
@@ -155,15 +181,19 @@ class KVCacheSendingLayerThread(threading.Thread):
torch.npu.set_device(device)
self.ready_event.set()
while True:
req_id, req_meta, layer_index, key, value = self.send_queue.get()
self._handle_request(req_id, req_meta, layer_index, key, value)
req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get(
)
self._handle_request(req_id, req_meta, layer_index, key, value,
reshape_cache_event)
def _handle_request(self, req_id, req_meta, layer_index, key, value):
def _handle_request(self, req_id, req_meta, layer_index, key, value,
reshape_cache_event):
try:
logger.debug(
f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
)
self._transfer_kv_cache(req_id, req_meta, layer_index, key, value)
self._transfer_kv_cache(req_id, req_meta, layer_index, key, value,
reshape_cache_event)
logger.debug(
f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
)
@@ -171,13 +201,8 @@ class KVCacheSendingLayerThread(threading.Thread):
logger.error("Failed to transfer KV cache for request "
f"{req_id}: {e}")
def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value):
# send kv layer to remote
if len(req_meta.local_block_ids) == 0:
logger.debug(
f"Cancelling KV cache transfer for request {req_id}. Reason: No local blocks to transfer."
)
return
def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value,
reshape_cache_event):
# not need to send kv cache
if self.tp_rank % self.num_head_replica != 0:
logger.debug(
@@ -227,7 +252,13 @@ class KVCacheSendingLayerThread(threading.Thread):
length_list.append(length)
if self.current_layer != layer_index:
self.current_layer = layer_index
self.model_stream.synchronize()
"""
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
This issue will be fixed in CANN version 8.5.rc1.
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
to resolve this issue before the 8.5.RC1 release.
"""
reshape_cache_event.synchronize()
ret = self.engine.batch_transfer_sync_write(
session_id, src_list, dst_list, length_list)
if ret < 0:
@@ -285,7 +316,7 @@ class KVCacheSendingLayerThread(threading.Thread):
logger.error("Mooncake transfer failed for request %s", req_id)
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
if layer_index == (self.total_layers - 1):
if layer_index == (self.total_layers - 1) and req_meta.chunk_finish:
self.callback_func(req_id, req_meta)
@@ -376,7 +407,8 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
request_id: str,
local_block_ids: list[int],
kv_transfer_params: dict[str, Any],
token_ids: Optional[list[int]] = None):
token_ids: Optional[list[int]] = None,
chunk_finish: bool = False):
self.requests[request_id] = ReqMeta(
token_ids=token_ids or [],
local_block_ids=local_block_ids,
@@ -389,7 +421,7 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
remote_kv_caches_base_addr=kv_transfer_params.get(
"remote_kv_caches_base_addr", None),
metaserver=kv_transfer_params.get("metaserver", None),
)
chunk_finish=chunk_finish)
class MooncakeLayerwiseConnector(KVConnectorBase_V1):
@@ -398,6 +430,7 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1):
vllm_config: VllmConfig,
role: KVConnectorRole,
kv_cache_config: Optional[KVCacheConfig] = None):
super().__init__(vllm_config, role, kv_cache_config)
assert vllm_config.kv_transfer_config is not None
self.engine_id = vllm_config.kv_transfer_config.engine_id
self._connector_metadata = MooncakeLayerwiseConnectorMetadata()
@@ -509,9 +542,11 @@ class MooncakeLayerwiseConnectorScheduler:
# the scheduler. Used to make metadata passed to Worker.
self._reqs_need_recv: dict[str, tuple[Request, list[int],
list[int]]] = {}
self._reqs_need_send_layerwise: dict[str, tuple[
int, list[int],
Request]] = {} # req_id, (len(prompt), local_block_ids, request)
self._reqs_need_send_layerwise: dict[str, SendReqInfo] = {}
self.executor = ThreadPoolExecutor(32)
self.metaserver_client = httpx.Client(
limits=httpx.Limits(max_connections=100000), timeout=None)
def get_num_new_matched_tokens(
self, request: "Request",
@@ -571,14 +606,53 @@ class MooncakeLayerwiseConnectorScheduler:
params["do_remote_prefill"] = False
logger.info(
f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}"
)
# All parameters here should appear in the returned dict of
# request_finished in the scheduler side except "request_id".
kv_transfer_params = dict(
token_ids=[],
request_id=request.request_id,
do_remote_prefill=False,
do_remote_decode=True,
remote_block_ids=local_block_ids,
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
)
future = self.executor.submit(
self._access_metaserver,
url=params.get("metaserver", None),
message=kv_transfer_params,
)
def handle_exception(future):
if future.exception():
logger.error(
f"Access metaserver fail: {future.exception()}")
future.add_done_callback(handle_exception)
# Layerwise prefiller add request need send
if params is not None and params.get("do_remote_decode"):
local_block_ids = (blocks.get_block_ids()[0])
logger.debug(
f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue"
)
self._reqs_need_send_layerwise[request.request_id] = (len(
request.all_token_ids), local_block_ids, request)
remote_block_ids = copy.deepcopy(params["remote_block_ids"])
remote_cache_tokens = (
(len(request.all_token_ids) + self.block_size - 1) //
self.block_size - len(remote_block_ids)) * self.block_size
local_transferred_tokens = remote_cache_tokens
local_computed_tokens = 0
self._reqs_need_send_layerwise[request.request_id] = SendReqInfo(
local_block_ids=local_block_ids,
remote_block_ids=remote_block_ids,
remote_cache_tokens=remote_cache_tokens,
local_transferred_tokens=local_transferred_tokens,
local_computed_tokens=local_computed_tokens,
request=request)
def build_connector_meta(
self,
@@ -586,55 +660,118 @@ class MooncakeLayerwiseConnectorScheduler:
) -> KVConnectorMetadata:
meta = MooncakeLayerwiseConnectorMetadata()
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, token_ids,
block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
# For the case where there are no remote blocks to pull
# (block_ids is empty), we don't need to schedule
# an async read on the worker side.
meta.add_new_req(request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
token_ids=token_ids)
if self.vllm_config.kv_transfer_config.is_kv_consumer:
# Loop through scheduled reqs and convert to ReqMeta.
for req_id, (req, token_ids,
block_ids) in self._reqs_need_recv.items():
assert req.kv_transfer_params is not None
# For the case where there are no remote blocks to pull
# (block_ids is empty), we don't need to schedule
# an async read on the worker side.
meta.add_new_req(request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
token_ids=token_ids)
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
# Clear the list once workers start the transfers
self._reqs_need_recv.clear()
else:
cached_reqs = scheduler_output.scheduled_cached_reqs
new_reqs = scheduler_output.scheduled_new_reqs
scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
# update local block ids
for req_id, new_blocks in zip(cached_reqs.req_ids,
cached_reqs.new_block_ids):
if req_id in self._reqs_need_send_layerwise and new_blocks is not None:
self._reqs_need_send_layerwise[
req_id].extend_local_block_ids(new_blocks[0])
cached_reqs = scheduler_output.scheduled_cached_reqs
new_reqs = scheduler_output.scheduled_new_reqs
for req_id, new_blocks in zip(cached_reqs.req_ids,
cached_reqs.new_block_ids):
if req_id in self._reqs_need_send_layerwise and new_blocks is not None:
total_tokens, block_ids, req = self._reqs_need_send_layerwise[
req_id]
block_ids.extend(new_blocks[0])
computed_tokens = dict(
list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens))
+ [(x.req_id, x.num_computed_tokens) for x in new_reqs])
for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(
):
if req_id in self._reqs_need_send_layerwise:
send_req_info = self._reqs_need_send_layerwise[req_id]
# update local computed tokens, not transfer spec decode tokens
spec_decode_tokens = len(
scheduled_spec_decode_tokens[req_id]) if (
req_id in scheduled_spec_decode_tokens) else 0
send_req_info.update_computed_tokens(
computed_tokens.get(req_id, 0) + scheduled_tokens -
spec_decode_tokens)
computed_tokens = dict(
list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) +
[(x.req_id, x.num_computed_tokens) for x in new_reqs])
for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(
):
if req_id in self._reqs_need_send_layerwise:
total_tokens, block_ids, req = self._reqs_need_send_layerwise[
req_id]
current_tokens = computed_tokens.get(req_id,
0) + scheduled_tokens
if current_tokens >= total_tokens:
logger.debug(
f"MooncakeLayerwiseConnector build_connector_meta: add {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
)
meta.add_new_req(request_id=req_id,
local_block_ids=block_ids,
kv_transfer_params=req.kv_transfer_params,
token_ids=[])
self._reqs_need_send_layerwise.pop(req_id)
else:
logger.debug(
f"MooncakeLayerwiseConnector build_connector_meta: skip {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
)
def add_tranfer_task(req_id,
send_req_info: SendReqInfo,
chunk_finish=False):
local_block_ids, remote_block_ids, remote_cache_tokens, local_transferred_tokens, local_computed_tokens, request = send_req_info.unpack(
)
local_trans_block_ids = local_block_ids[(
local_transferred_tokens //
self.block_size):(local_computed_tokens //
self.block_size)]
remote_trans_block_ids = remote_block_ids[(
(local_transferred_tokens - remote_cache_tokens) //
self.block_size):((local_computed_tokens -
remote_cache_tokens) //
self.block_size)]
request.kv_transfer_params[
"remote_block_ids"] = remote_trans_block_ids
assert len(local_trans_block_ids) == len(
remote_trans_block_ids
), f"len of local trans block ids : {len(local_trans_block_ids)} not equal to the len of remote trans block ids : {len(remote_trans_block_ids)}"
adjusted_tokens = local_computed_tokens - (
self.block_size -
1) if chunk_finish else local_computed_tokens
logger.info(
f"MooncakeLayerwiseConnector scheduler add transfer task: {req_id=} {local_block_ids=} {remote_block_ids=} {local_trans_block_ids=} {remote_trans_block_ids=} local_computed_tokens={adjusted_tokens} request.all_token_ids={len(request.all_token_ids)}"
)
meta.add_new_req(
request_id=req_id,
local_block_ids=local_trans_block_ids,
kv_transfer_params=request.kv_transfer_params,
token_ids=[],
chunk_finish=chunk_finish)
# update local_transferred_tokens
local_transferred_tokens = (
local_computed_tokens //
self.block_size) * self.block_size
send_req_info.update_transferred_tokens(
local_transferred_tokens)
# no chunk or last chunk
if send_req_info.local_computed_tokens >= len(
send_req_info.request.all_token_ids):
send_req_info.update_computed_tokens(
send_req_info.local_computed_tokens +
self.block_size - 1)
add_tranfer_task(req_id,
send_req_info,
chunk_finish=True)
self._reqs_need_send_layerwise.pop(req_id)
# chunk
elif (send_req_info.local_computed_tokens //
self.block_size) - (
send_req_info.local_transferred_tokens //
self.block_size) > 0:
add_tranfer_task(req_id, send_req_info)
return meta
def _access_metaserver(self, url, message):
success = False
retry = 0
while retry < 3 and success is False:
retry += 1
try:
self.metaserver_client.post(url, json=message)
success = True
except Exception as e:
logger.error(
f"Failed to connect to metaserver: {url}, retry {retry} time."
)
if retry == 3:
raise e
def request_finished(
self,
request: "Request",
@@ -676,11 +813,6 @@ class MooncakeLayerwiseConnectorWorker:
self.total_layers = vllm_config.model_config.get_num_layers(
vllm_config.parallel_config)
self.executor = ThreadPoolExecutor(32)
self.metaserver_client = httpx.Client(
limits=httpx.Limits(max_connections=100000),
timeout=None) if self.tp_rank == 0 else None
# Handshake base port
self.side_channel_port = (
vllm_config.kv_transfer_config.kv_port +
@@ -834,21 +966,6 @@ class MooncakeLayerwiseConnectorWorker:
self.kv_recv_layer_thread.start()
ready_event.wait()
def _access_metaserver(self, url, message):
success = False
retry = 0
while retry < 3 and success is False:
retry += 1
try:
self.metaserver_client.post(url, json=message)
success = True
except Exception as e:
logger.error(
f"Failed to connect to metaserver: {url}, retry {retry} time."
)
if retry == 3:
raise e
def get_finished(self) -> tuple[set[str], set[str]]:
done_recving = (
self.kv_recv_layer_thread.
@@ -865,35 +982,6 @@ class MooncakeLayerwiseConnectorWorker:
self.current_layer = 0
if self.vllm_config.kv_transfer_config.is_kv_consumer:
for req_id, meta in metadata.requests.items():
if self.tp_rank % self.tp_size == 0:
logger.info(
f"Send request: {req_id} to proxy metaserver: {meta.metaserver}"
)
# All parameters here should appear in the returned dict of
# request_finished in the scheduler side except "request_id".
kv_transfer_params = dict(
token_ids=meta.token_ids,
request_id=req_id,
do_remote_prefill=False,
do_remote_decode=True,
remote_block_ids=meta.local_block_ids,
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
)
future = self.executor.submit(
self._access_metaserver,
url=meta.metaserver,
message=kv_transfer_params,
)
def handle_exception(future):
if future.exception():
logger.error(
f"Access metaserver fail: {future.exception()}"
)
future.add_done_callback(handle_exception)
assert self.kv_recv_layer_thread is not None
with self.kv_recv_layer_thread.lock:
self.kv_recv_layer_thread.task_tracker[req_id] = 0
@@ -907,12 +995,12 @@ class MooncakeLayerwiseConnectorWorker:
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(
):
# enable decode prefix cache
for request in connector_metadata.requests.values():
assert len(request.local_block_ids) >= len(
request.remote_block_ids
), "When prefix cache enabled, remote KVCacheBlocks num should not larger than local KVCacheBlocks num."
request.local_block_ids = request.local_block_ids[
-len(request.remote_block_ids):]
if self.use_mla:
reshape_cache_event = attn_metadata[
layer_name].reshape_cache_event
else:
reshape_cache_event = attn_metadata.reshape_cache_event
if self.pd_head_ratio != 1:
def sort_kv_cache(input_kv: list[list[int]]):
@@ -964,8 +1052,10 @@ class MooncakeLayerwiseConnectorWorker:
f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
)
assert self.kv_send_layer_thread is not None
assert reshape_cache_event is not None
self.kv_send_layer_thread.send_queue.put(
(req_id, req_meta_update, self.current_layer, key, value))
(req_id, req_meta_update, self.current_layer, key, value,
reshape_cache_event))
self.current_layer += 1
def _get_remote_socket(