[P/D] Improve the performance of Layerwise Connector (#5303)
### What this PR does / why we need it?
Improve the performance of the Layerwise Connector. The main changes are (minimal sketches follow this list):
1. Use event synchronization instead of stream synchronization.
2. Access the metaserver during scheduling instead of on the worker side.
3. Transfer the KV cache after each chunked-prefill segment.
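
Two minimal sketches of points 1 and 3, for orientation before reading the diff. The helper names and parameters below are hypothetical, not the connector's actual API; the first sketch assumes the Ascend `torch_npu` extension, which exposes `torch.npu.Event`/`torch.npu.Stream` with the same interface as their CUDA counterparts.

```python
import torch
import torch_npu  # noqa: F401  # assumption: registers the torch.npu backend on Ascend

def record_layer_ready(model_stream: "torch.npu.Stream") -> "torch.npu.Event":
    # Record an event right after one layer's KV cache has been written on the
    # model stream, instead of synchronizing the whole stream later.
    event = torch.npu.Event()
    event.record(model_stream)
    return event

def wait_for_layer(event: "torch.npu.Event") -> None:
    # In the sending thread, block only until this layer's event completes;
    # other work queued on the model stream keeps running.
    event.synchronize()
```

Point 3 boils down to block-aligned bookkeeping: after each chunked-prefill step, only the blocks that became complete in that step are handed to the sender. A sketch under the same caveat (names are illustrative):

```python
def blocks_for_this_chunk(local_block_ids: list[int], transferred_tokens: int,
                          computed_tokens: int, block_size: int) -> list[int]:
    # Blocks already sent in earlier chunks are skipped; only the span that
    # finished in this chunked-prefill step is transferred.
    start = transferred_tokens // block_size
    end = computed_tokens // block_size
    return local_block_ids[start:end]
```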
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By CI.
- vLLM version: release/v0.13.0
- vLLM main: 5fbfa8d9ef
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
@@ -65,6 +65,32 @@ class ReqMeta:
     remote_te_rpc_port: Optional[int]
     remote_kv_caches_base_addr: Optional[list[int]]
     metaserver: Optional[str]
+    chunk_finish: Optional[bool]
 
 
+@dataclass
+class SendReqInfo:
+    local_block_ids: list[int]
+    remote_block_ids: List[int]
+    remote_cache_tokens: int
+    local_transferred_tokens: int
+    local_computed_tokens: int
+    request: "Request"
+
+    def extend_local_block_ids(self, new_block_ids: List[int]) -> None:
+        """extend local block ids for this step"""
+        self.local_block_ids.extend(new_block_ids)
+
+    def update_computed_tokens(self, computed_tokens: int) -> None:
+        """update local computed tokens for this step"""
+        self.local_computed_tokens = computed_tokens
+
+    def update_transferred_tokens(self, transferred_tokens: int) -> None:
+        """update transferred tokens for this step"""
+        self.local_transferred_tokens = transferred_tokens
+
+    def unpack(self):
+        return self.local_block_ids, self.remote_block_ids, self.remote_cache_tokens, self.local_transferred_tokens, self.local_computed_tokens, self.request
+
+
 @dataclass
@@ -144,7 +170,7 @@ class KVCacheSendingLayerThread(threading.Thread):
             raise RuntimeError("Mooncake memory registration failed. ")
 
         self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor,
-                                            torch.Tensor]]()
+                                            torch.Tensor, torch.npu.Event]]()
 
         self.ready_event = ready_event
         self.callback_func = callback_func
@@ -155,15 +181,19 @@ class KVCacheSendingLayerThread(threading.Thread):
         torch.npu.set_device(device)
         self.ready_event.set()
         while True:
-            req_id, req_meta, layer_index, key, value = self.send_queue.get()
-            self._handle_request(req_id, req_meta, layer_index, key, value)
+            req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get(
+            )
+            self._handle_request(req_id, req_meta, layer_index, key, value,
+                                 reshape_cache_event)
 
-    def _handle_request(self, req_id, req_meta, layer_index, key, value):
+    def _handle_request(self, req_id, req_meta, layer_index, key, value,
+                        reshape_cache_event):
         try:
             logger.debug(
                 f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
             )
-            self._transfer_kv_cache(req_id, req_meta, layer_index, key, value)
+            self._transfer_kv_cache(req_id, req_meta, layer_index, key, value,
+                                    reshape_cache_event)
             logger.debug(
                 f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
             )
@@ -171,13 +201,8 @@ class KVCacheSendingLayerThread(threading.Thread):
             logger.error("Failed to transfer KV cache for request "
                          f"{req_id}: {e}")
 
-    def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value):
-        # send kv layer to remote
-        if len(req_meta.local_block_ids) == 0:
-            logger.debug(
-                f"Cancelling KV cache transfer for request {req_id}. Reason: No local blocks to transfer."
-            )
-            return
-
+    def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value,
+                           reshape_cache_event):
+        # not need to send kv cache
         if self.tp_rank % self.num_head_replica != 0:
             logger.debug(
@@ -227,7 +252,13 @@ class KVCacheSendingLayerThread(threading.Thread):
             length_list.append(length)
         if self.current_layer != layer_index:
             self.current_layer = layer_index
-            self.model_stream.synchronize()
+            """
+            Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
+            This issue will be fixed in CANN version 8.5.rc1.
+            You can manually build the master branch of the project at https://gitcode.com/cann/hixl
+            to resolve this issue before the 8.5.RC1 release.
+            """
+            reshape_cache_event.synchronize()
         ret = self.engine.batch_transfer_sync_write(
             session_id, src_list, dst_list, length_list)
         if ret < 0:
@@ -285,7 +316,7 @@ class KVCacheSendingLayerThread(threading.Thread):
             logger.error("Mooncake transfer failed for request %s", req_id)
             raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
 
-        if layer_index == (self.total_layers - 1):
+        if layer_index == (self.total_layers - 1) and req_meta.chunk_finish:
             self.callback_func(req_id, req_meta)
 
 
@@ -376,7 +407,8 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
                     request_id: str,
                     local_block_ids: list[int],
                     kv_transfer_params: dict[str, Any],
-                    token_ids: Optional[list[int]] = None):
+                    token_ids: Optional[list[int]] = None,
+                    chunk_finish: bool = False):
         self.requests[request_id] = ReqMeta(
             token_ids=token_ids or [],
             local_block_ids=local_block_ids,
@@ -389,7 +421,7 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
             remote_kv_caches_base_addr=kv_transfer_params.get(
                 "remote_kv_caches_base_addr", None),
             metaserver=kv_transfer_params.get("metaserver", None),
-        )
+            chunk_finish=chunk_finish)
 
 
 class MooncakeLayerwiseConnector(KVConnectorBase_V1):
@@ -398,6 +430,7 @@ class MooncakeLayerwiseConnector(KVConnectorBase_V1):
                  vllm_config: VllmConfig,
                  role: KVConnectorRole,
                  kv_cache_config: Optional[KVCacheConfig] = None):
         super().__init__(vllm_config, role, kv_cache_config)
         assert vllm_config.kv_transfer_config is not None
         self.engine_id = vllm_config.kv_transfer_config.engine_id
+        self._connector_metadata = MooncakeLayerwiseConnectorMetadata()
@@ -509,9 +542,11 @@ class MooncakeLayerwiseConnectorScheduler:
         # the scheduler. Used to make metadata passed to Worker.
         self._reqs_need_recv: dict[str, tuple[Request, list[int],
                                               list[int]]] = {}
-        self._reqs_need_send_layerwise: dict[str, tuple[
-            int, list[int],
-            Request]] = {} # req_id, (len(prompt), local_block_ids, request)
+        self._reqs_need_send_layerwise: dict[str, SendReqInfo] = {}
+
+        self.executor = ThreadPoolExecutor(32)
+        self.metaserver_client = httpx.Client(
+            limits=httpx.Limits(max_connections=100000), timeout=None)
 
     def get_num_new_matched_tokens(
             self, request: "Request",
@@ -571,14 +606,53 @@ class MooncakeLayerwiseConnectorScheduler:
 
             params["do_remote_prefill"] = False
 
+            logger.info(
+                f"Send request: {request.request_id} to proxy metaserver: {params.get('metaserver', None)}"
+            )
+            # All parameters here should appear in the returned dict of
+            # request_finished in the scheduler side except "request_id".
+            kv_transfer_params = dict(
+                token_ids=[],
+                request_id=request.request_id,
+                do_remote_prefill=False,
+                do_remote_decode=True,
+                remote_block_ids=local_block_ids,
+                remote_engine_id=self.engine_id,
+                remote_host=self.side_channel_host,
+                remote_port=self.side_channel_port,
+            )
+            future = self.executor.submit(
+                self._access_metaserver,
+                url=params.get("metaserver", None),
+                message=kv_transfer_params,
+            )
+
+            def handle_exception(future):
+                if future.exception():
+                    logger.error(
+                        f"Access metaserver fail: {future.exception()}")
+
+            future.add_done_callback(handle_exception)
+
         # Layerwise prefiller add request need send
         if params is not None and params.get("do_remote_decode"):
             local_block_ids = (blocks.get_block_ids()[0])
             logger.debug(
                 f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need send queue"
             )
-            self._reqs_need_send_layerwise[request.request_id] = (len(
-                request.all_token_ids), local_block_ids, request)
+            remote_block_ids = copy.deepcopy(params["remote_block_ids"])
+            remote_cache_tokens = (
+                (len(request.all_token_ids) + self.block_size - 1) //
+                self.block_size - len(remote_block_ids)) * self.block_size
+            local_transferred_tokens = remote_cache_tokens
+            local_computed_tokens = 0
+            self._reqs_need_send_layerwise[request.request_id] = SendReqInfo(
+                local_block_ids=local_block_ids,
+                remote_block_ids=remote_block_ids,
+                remote_cache_tokens=remote_cache_tokens,
+                local_transferred_tokens=local_transferred_tokens,
+                local_computed_tokens=local_computed_tokens,
+                request=request)
 
     def build_connector_meta(
             self,
@@ -586,55 +660,118 @@ class MooncakeLayerwiseConnectorScheduler:
     ) -> KVConnectorMetadata:
         meta = MooncakeLayerwiseConnectorMetadata()
 
-        # Loop through scheduled reqs and convert to ReqMeta.
-        for req_id, (req, token_ids,
-                     block_ids) in self._reqs_need_recv.items():
-            assert req.kv_transfer_params is not None
-            # For the case where there are no remote blocks to pull
-            # (block_ids is empty), we don't need to schedule
-            # an async read on the worker side.
-            meta.add_new_req(request_id=req_id,
-                             local_block_ids=block_ids,
-                             kv_transfer_params=req.kv_transfer_params,
-                             token_ids=token_ids)
+        if self.vllm_config.kv_transfer_config.is_kv_consumer:
+            # Loop through scheduled reqs and convert to ReqMeta.
+            for req_id, (req, token_ids,
+                         block_ids) in self._reqs_need_recv.items():
+                assert req.kv_transfer_params is not None
+                # For the case where there are no remote blocks to pull
+                # (block_ids is empty), we don't need to schedule
+                # an async read on the worker side.
+                meta.add_new_req(request_id=req_id,
+                                 local_block_ids=block_ids,
+                                 kv_transfer_params=req.kv_transfer_params,
+                                 token_ids=token_ids)
 
-        # Clear the list once workers start the transfers
-        self._reqs_need_recv.clear()
+            # Clear the list once workers start the transfers
+            self._reqs_need_recv.clear()
+        else:
+            cached_reqs = scheduler_output.scheduled_cached_reqs
+            new_reqs = scheduler_output.scheduled_new_reqs
+            scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
+            # update local block ids
+            for req_id, new_blocks in zip(cached_reqs.req_ids,
+                                          cached_reqs.new_block_ids):
+                if req_id in self._reqs_need_send_layerwise and new_blocks is not None:
+                    self._reqs_need_send_layerwise[
+                        req_id].extend_local_block_ids(new_blocks[0])
 
-        cached_reqs = scheduler_output.scheduled_cached_reqs
-        new_reqs = scheduler_output.scheduled_new_reqs
-        for req_id, new_blocks in zip(cached_reqs.req_ids,
-                                      cached_reqs.new_block_ids):
-            if req_id in self._reqs_need_send_layerwise and new_blocks is not None:
-                total_tokens, block_ids, req = self._reqs_need_send_layerwise[
-                    req_id]
-                block_ids.extend(new_blocks[0])
+            computed_tokens = dict(
+                list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens))
+                + [(x.req_id, x.num_computed_tokens) for x in new_reqs])
+            for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(
+            ):
+                if req_id in self._reqs_need_send_layerwise:
+                    send_req_info = self._reqs_need_send_layerwise[req_id]
+                    # update local computed tokens, not transfer spec decode tokens
+                    spec_decode_tokens = len(
+                        scheduled_spec_decode_tokens[req_id]) if (
+                            req_id in scheduled_spec_decode_tokens) else 0
+                    send_req_info.update_computed_tokens(
+                        computed_tokens.get(req_id, 0) + scheduled_tokens -
+                        spec_decode_tokens)
 
-        computed_tokens = dict(
-            list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens)) +
-            [(x.req_id, x.num_computed_tokens) for x in new_reqs])
-        for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items(
-        ):
-            if req_id in self._reqs_need_send_layerwise:
-                total_tokens, block_ids, req = self._reqs_need_send_layerwise[
-                    req_id]
-                current_tokens = computed_tokens.get(req_id,
-                                                     0) + scheduled_tokens
-                if current_tokens >= total_tokens:
-                    logger.debug(
-                        f"MooncakeLayerwiseConnector build_connector_meta: add {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
-                    )
-                    meta.add_new_req(request_id=req_id,
-                                     local_block_ids=block_ids,
-                                     kv_transfer_params=req.kv_transfer_params,
-                                     token_ids=[])
-                    self._reqs_need_send_layerwise.pop(req_id)
-                else:
-                    logger.debug(
-                        f"MooncakeLayerwiseConnector build_connector_meta: skip {req_id}, current tokens({current_tokens}={computed_tokens.get(req_id,0)}+{scheduled_tokens}), total tokens({total_tokens})"
-                    )
+                    def add_tranfer_task(req_id,
+                                         send_req_info: SendReqInfo,
+                                         chunk_finish=False):
+                        local_block_ids, remote_block_ids, remote_cache_tokens, local_transferred_tokens, local_computed_tokens, request = send_req_info.unpack(
+                        )
+                        local_trans_block_ids = local_block_ids[(
+                            local_transferred_tokens //
+                            self.block_size):(local_computed_tokens //
+                                              self.block_size)]
+                        remote_trans_block_ids = remote_block_ids[(
+                            (local_transferred_tokens - remote_cache_tokens) //
+                            self.block_size):((local_computed_tokens -
+                                               remote_cache_tokens) //
+                                              self.block_size)]
+                        request.kv_transfer_params[
+                            "remote_block_ids"] = remote_trans_block_ids
+                        assert len(local_trans_block_ids) == len(
+                            remote_trans_block_ids
+                        ), f"len of local trans block ids : {len(local_trans_block_ids)} not equal to the len of remote trans block ids : {len(remote_trans_block_ids)}"
+                        adjusted_tokens = local_computed_tokens - (
+                            self.block_size -
+                            1) if chunk_finish else local_computed_tokens
+                        logger.info(
+                            f"MooncakeLayerwiseConnector scheduler add transfer task: {req_id=} {local_block_ids=} {remote_block_ids=} {local_trans_block_ids=} {remote_trans_block_ids=} local_computed_tokens={adjusted_tokens} request.all_token_ids={len(request.all_token_ids)}"
+                        )
+                        meta.add_new_req(
+                            request_id=req_id,
+                            local_block_ids=local_trans_block_ids,
+                            kv_transfer_params=request.kv_transfer_params,
+                            token_ids=[],
+                            chunk_finish=chunk_finish)
+                        # update local_transferred_tokens
+                        local_transferred_tokens = (
+                            local_computed_tokens //
+                            self.block_size) * self.block_size
+                        send_req_info.update_transferred_tokens(
+                            local_transferred_tokens)
+
+                    # no chunk or last chunk
+                    if send_req_info.local_computed_tokens >= len(
+                            send_req_info.request.all_token_ids):
+                        send_req_info.update_computed_tokens(
+                            send_req_info.local_computed_tokens +
+                            self.block_size - 1)
+                        add_tranfer_task(req_id,
+                                         send_req_info,
+                                         chunk_finish=True)
+                        self._reqs_need_send_layerwise.pop(req_id)
+                    # chunk
+                    elif (send_req_info.local_computed_tokens //
+                          self.block_size) - (
+                              send_req_info.local_transferred_tokens //
+                              self.block_size) > 0:
+                        add_tranfer_task(req_id, send_req_info)
         return meta
 
+    def _access_metaserver(self, url, message):
+        success = False
+        retry = 0
+        while retry < 3 and success is False:
+            retry += 1
+            try:
+                self.metaserver_client.post(url, json=message)
+                success = True
+            except Exception as e:
+                logger.error(
+                    f"Failed to connect to metaserver: {url}, retry {retry} time."
+                )
+                if retry == 3:
+                    raise e
+
     def request_finished(
         self,
         request: "Request",
@@ -676,11 +813,6 @@ class MooncakeLayerwiseConnectorWorker:
         self.total_layers = vllm_config.model_config.get_num_layers(
             vllm_config.parallel_config)
 
-        self.executor = ThreadPoolExecutor(32)
-        self.metaserver_client = httpx.Client(
-            limits=httpx.Limits(max_connections=100000),
-            timeout=None) if self.tp_rank == 0 else None
-
         # Handshake base port
         self.side_channel_port = (
             vllm_config.kv_transfer_config.kv_port +
@@ -834,21 +966,6 @@ class MooncakeLayerwiseConnectorWorker:
         self.kv_recv_layer_thread.start()
         ready_event.wait()
 
-    def _access_metaserver(self, url, message):
-        success = False
-        retry = 0
-        while retry < 3 and success is False:
-            retry += 1
-            try:
-                self.metaserver_client.post(url, json=message)
-                success = True
-            except Exception as e:
-                logger.error(
-                    f"Failed to connect to metaserver: {url}, retry {retry} time."
-                )
-                if retry == 3:
-                    raise e
-
     def get_finished(self) -> tuple[set[str], set[str]]:
         done_recving = (
             self.kv_recv_layer_thread.
@@ -865,35 +982,6 @@ class MooncakeLayerwiseConnectorWorker:
         self.current_layer = 0
         if self.vllm_config.kv_transfer_config.is_kv_consumer:
             for req_id, meta in metadata.requests.items():
-                if self.tp_rank % self.tp_size == 0:
-                    logger.info(
-                        f"Send request: {req_id} to proxy metaserver: {meta.metaserver}"
-                    )
-                    # All parameters here should appear in the returned dict of
-                    # request_finished in the scheduler side except "request_id".
-                    kv_transfer_params = dict(
-                        token_ids=meta.token_ids,
-                        request_id=req_id,
-                        do_remote_prefill=False,
-                        do_remote_decode=True,
-                        remote_block_ids=meta.local_block_ids,
-                        remote_engine_id=self.engine_id,
-                        remote_host=self.side_channel_host,
-                        remote_port=self.side_channel_port,
-                    )
-                    future = self.executor.submit(
-                        self._access_metaserver,
-                        url=meta.metaserver,
-                        message=kv_transfer_params,
-                    )
-
-                    def handle_exception(future):
-                        if future.exception():
-                            logger.error(
-                                f"Access metaserver fail: {future.exception()}"
-                            )
-
-                    future.add_done_callback(handle_exception)
                 assert self.kv_recv_layer_thread is not None
                 with self.kv_recv_layer_thread.lock:
                     self.kv_recv_layer_thread.task_tracker[req_id] = 0
@@ -907,12 +995,12 @@ class MooncakeLayerwiseConnectorWorker:
         if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(
         ):
-            # enable decode prefix cache
-            for request in connector_metadata.requests.values():
-                assert len(request.local_block_ids) >= len(
-                    request.remote_block_ids
-                ), "When prefix cache enabled, remote KVCacheBlocks num should not larger than local KVCacheBlocks num."
-                request.local_block_ids = request.local_block_ids[
-                    -len(request.remote_block_ids):]
+            if self.use_mla:
+                reshape_cache_event = attn_metadata[
+                    layer_name].reshape_cache_event
+            else:
+                reshape_cache_event = attn_metadata.reshape_cache_event
+
             if self.pd_head_ratio != 1:
 
                 def sort_kv_cache(input_kv: list[list[int]]):
@@ -964,8 +1052,10 @@ class MooncakeLayerwiseConnectorWorker:
                 f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
             )
             assert self.kv_send_layer_thread is not None
+            assert reshape_cache_event is not None
             self.kv_send_layer_thread.send_queue.put(
-                (req_id, req_meta_update, self.current_layer, key, value))
+                (req_id, req_meta_update, self.current_layer, key, value,
+                 reshape_cache_event))
         self.current_layer += 1
 
     def _get_remote_socket(