[P/D][PCP] mooncake layerwise support pcp function (#6627)
### What this PR does / why we need it?
Add Prefill Context Parallelism (PCP) support to the Mooncake layerwise connector.
This introduces explicit awareness of Prefill Context Parallelism (PCP) and Decode
Context Parallelism (DCP) in the Mooncake layerwise KV cache transfer mechanism,
allowing more granular control and awareness of the parallel configuration during
prefill-to-decode data transfer.
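For context, the connector worker now reads `parallel_config.prefill_context_parallel_size` and `parallel_config.decode_context_parallel_size` and splits KV blocks across the flattened context-parallel index `pcp_rank * dcp_size + dcp_rank`. A minimal sketch of that bookkeeping follows (illustrative only; the `ParallelCoords` helper is hypothetical and not part of this PR):

```python
# Illustrative sketch: deriving context-parallel coordinates and the flattened
# CP index used to round-robin KV blocks across ranks. Attribute names mirror
# the ones touched in this diff; the dataclass itself is hypothetical.
from dataclasses import dataclass


@dataclass
class ParallelCoords:
    pcp_size: int   # prefill context parallel size
    dcp_size: int   # decode context parallel size
    pcp_rank: int
    dcp_rank: int

    @property
    def cp_size(self) -> int:
        # Total context-parallel world size: cp_size = pcp_size * dcp_size.
        return self.pcp_size * self.dcp_size

    @property
    def cp_index(self) -> int:
        # Flattened index pcp_rank * dcp_size + dcp_rank used when splitting blocks.
        return self.pcp_rank * self.dcp_size + self.dcp_rank


coords = ParallelCoords(pcp_size=2, dcp_size=2, pcp_rank=1, dcp_rank=0)
assert coords.cp_size == 4 and coords.cp_index == 2
```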
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
By CI.
- vLLM version: v0.15.0
- vLLM main: d7e17aaacd
---------
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
@@ -23,9 +23,16 @@ import torch_npu
import zmq
from mooncake.engine import TransferEngine # type: ignore
from vllm.config import VllmConfig
from vllm.distributed import get_pcp_group
from vllm.distributed.kv_transfer.kv_connector.v1.base import KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole
from vllm.distributed.parallel_state import get_tensor_model_parallel_rank, get_tp_group, get_world_group
from vllm.distributed.parallel_state import (
get_decode_context_model_parallel_rank,
get_tensor_model_parallel_rank,
get_tp_group,
get_world_group,
)
from vllm.logger import logger
from vllm.utils.math_utils import round_down
from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig

@@ -35,8 +42,13 @@ from vllm_ascend.distributed.kv_transfer.kv_p2p.mooncake_connector import GET_ME
from vllm_ascend.distributed.kv_transfer.utils.mooncake_transfer_engine import global_te
from vllm_ascend.distributed.kv_transfer.utils.utils import (
align_memory,
context_parallel_parameters_check,
get_cp_group,
get_local_remote_block_port_mappings,
get_transfer_mappings,
get_transfer_timeout_value,
kv_alltoall_and_rearrange,
parallel_info,
)
from vllm_ascend.utils import npu_stream_switch

@@ -68,7 +80,15 @@ class ReqMeta:
remote_te_rpc_port: int | None
remote_kv_caches_base_addr: list[int] | None
metaserver: str | None
chunk_finish: bool | None
remote_tp_size: int | None
remote_pcp_size: int | None
remote_dcp_size: int | None
chunk_finish: bool = False
prompt_len: int = 0
trans_count: int = 0
remote_cache_tokens: int = 0
local_computed_tokens: int = 0
local_transed_tokens: int = 0

@dataclass

@@ -100,8 +120,6 @@ class TransferMeta:
@dataclass
class SendReqInfo:
local_block_ids: list[int]
remote_block_ids: list[int]
remote_cache_tokens: int
local_transferred_tokens: int
local_computed_tokens: int
request: "Request"

@@ -121,8 +139,6 @@ class SendReqInfo:
def unpack(self):
return (
self.local_block_ids,
self.remote_block_ids,
self.remote_cache_tokens,
self.local_transferred_tokens,
self.local_computed_tokens,
self.request,

@@ -161,8 +177,6 @@ class KVCacheSendingLayerThread(threading.Thread):
kv_cache_base_addr: list[int],
use_mla: bool,
block_len: list[int],
decode_tp_size: int,
first_kv_cache: torch.Tensor,
k_buffer: torch.Tensor,
v_buffer: torch.Tensor,
resharding_stream: torch.npu.Stream,

@@ -178,7 +192,6 @@ class KVCacheSendingLayerThread(threading.Thread):
self.use_mla = use_mla
self.use_sparse = len(block_len) == 3
self.block_len = block_len
self._decode_tp_size = decode_tp_size
self.resharding_stream = resharding_stream
self.current_layer = -1

@@ -373,10 +386,10 @@ class KVCacheRecvingLayerThread(threading.Thread):
self.done_requests = set()
return finished_requests

def update_task(self, req_id):
def update_task(self, req_id, trans_count):
with self.lock:
self.task_tracker[req_id] += 1
if self.task_tracker[req_id] == self.pd_head_ratio:
if self.task_tracker[req_id] == trans_count:
self.task_tracker.pop(req_id)
self.done_requests.add(self.request_map[req_id])
self.request_map.pop(req_id)
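Note: the change above replaces the fixed `pd_head_ratio` completion threshold with a per-request `trans_count`, since with PCP/DCP the number of expected transfers can differ per request. A minimal, self-contained sketch of that counting (a simplified tracker, not the real receiving thread):

```python
# Sketch only: a request is considered done once the number of DONE_SENDING
# notifications reaches the trans_count reported by the prefill side.
import threading
from collections import defaultdict


class TaskTracker:
    def __init__(self):
        self.lock = threading.Lock()
        self.counts = defaultdict(int)
        self.done = set()

    def update_task(self, req_id: str, trans_count: int) -> None:
        with self.lock:
            self.counts[req_id] += 1
            if self.counts[req_id] == trans_count:
                self.counts.pop(req_id)
                self.done.add(req_id)


tracker = TaskTracker()
for _ in range(3):
    tracker.update_task("req-0", trans_count=3)
assert "req-0" in tracker.done
```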
@@ -411,7 +424,8 @@ class KVCacheRecvingLayerThread(threading.Thread):
elif msg[0] == DONE_SENDING_MSG:
logger.debug("Got DONE_RECVING_MSG for request %s", msg[1])
request_id = msg[1]
self.update_task(request_id)
trans_count = msg[2]
self.update_task(request_id, trans_count)
sock.send_multipart((identity, b"", b"ACK"))
else:
logger.error("Connection listener got unexpected message %s", msg)

@@ -431,6 +445,10 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
kv_transfer_params: dict[str, Any],
token_ids: list[int] | None = None,
chunk_finish: bool = False,
prompt_len: int = 0,
remote_cache_tokens: int = 0,
local_computed_tokens: int = 0,
local_transed_tokens: int = 0,
):
self.requests[request_id] = ReqMeta(
token_ids=token_ids or [],

@@ -442,7 +460,14 @@ class MooncakeLayerwiseConnectorMetadata(KVConnectorMetadata):
remote_te_rpc_port=kv_transfer_params.get("remote_te_rpc_port"),
remote_kv_caches_base_addr=kv_transfer_params.get("remote_kv_caches_base_addr"),
metaserver=kv_transfer_params.get("metaserver"),
remote_tp_size=kv_transfer_params.get("remote_tp_size"),
remote_pcp_size=kv_transfer_params.get("remote_pcp_size"),
remote_dcp_size=kv_transfer_params.get("remote_dcp_size"),
chunk_finish=chunk_finish,
remote_cache_tokens=remote_cache_tokens,
local_computed_tokens=local_computed_tokens,
prompt_len=prompt_len,
local_transed_tokens=local_transed_tokens,
)

@@ -605,7 +630,8 @@ class MooncakeLayerwiseConnectorScheduler:
)

if params is not None and params.get("do_remote_prefill"):
local_block_ids = blocks.get_unhashed_block_ids() if num_external_tokens > 0 else []
local_block_ids = (blocks.get_block_ids()[0]) if num_external_tokens > 0 else []
remote_cached_tokens = request.num_computed_tokens
# Get unhashed blocks to pull from remote.
logger.debug(
f"MooncakeLayerwiseConnector update_state_after_alloc: add {request.request_id} to need recv queue"

@@ -632,6 +658,10 @@ class MooncakeLayerwiseConnectorScheduler:
remote_engine_id=self.engine_id,
remote_host=self.side_channel_host,
remote_port=self.side_channel_port,
remote_tp_size=self.vllm_config.parallel_config.tensor_parallel_size,
remote_pcp_size=self.vllm_config.parallel_config.prefill_context_parallel_size,
remote_dcp_size=self.vllm_config.parallel_config.decode_context_parallel_size,
remote_cached_tokens=remote_cached_tokens,
)
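Note: the scheduler now advertises its PCP/DCP sizes in `kv_transfer_params`, and the metadata side reads them with `dict.get`, so a peer that predates this change simply yields `None`. A sketch of that payload (the concrete values below are made up for illustration):

```python
# Illustrative only: the shape of the kv_transfer_params advertised above and
# how the new PCP/DCP fields are read back on the receiving side.
kv_transfer_params = {
    "remote_engine_id": "engine-0",   # hypothetical value
    "remote_host": "10.0.0.1",        # hypothetical value
    "remote_port": 5600,              # hypothetical value
    "remote_tp_size": 8,
    "remote_pcp_size": 2,             # new in this PR
    "remote_dcp_size": 1,             # new in this PR
    "remote_cached_tokens": 0,
}

remote_pcp_size = kv_transfer_params.get("remote_pcp_size")  # 2, or None from an old peer
remote_dcp_size = kv_transfer_params.get("remote_dcp_size")
assert remote_pcp_size == 2 and remote_dcp_size == 1
```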
future = self.executor.submit(

@@ -658,8 +688,6 @@ class MooncakeLayerwiseConnectorScheduler:
local_computed_tokens = 0
self._reqs_need_send_layerwise[request.request_id] = SendReqInfo(
local_block_ids=local_block_ids,
remote_block_ids=remote_block_ids,
remote_cache_tokens=remote_cache_tokens,
local_transferred_tokens=local_transferred_tokens,
local_computed_tokens=local_computed_tokens,
request=request,

@@ -691,11 +719,9 @@ class MooncakeLayerwiseConnectorScheduler:
cached_reqs = scheduler_output.scheduled_cached_reqs
new_reqs = scheduler_output.scheduled_new_reqs
scheduled_spec_decode_tokens = scheduler_output.scheduled_spec_decode_tokens
# update local block ids
for req_id, new_blocks in zip(cached_reqs.req_ids, cached_reqs.new_block_ids):
if req_id in self._reqs_need_send_layerwise and new_blocks is not None:
self._reqs_need_send_layerwise[req_id].extend_local_block_ids(new_blocks[0])

computed_tokens = dict(
list(zip(cached_reqs.req_ids, cached_reqs.num_computed_tokens))
+ [(x.req_id, x.num_computed_tokens) for x in new_reqs]

@@ -703,6 +729,10 @@
for req_id, scheduled_tokens in scheduler_output.num_scheduled_tokens.items():
if req_id in self._reqs_need_send_layerwise:
send_req_info = self._reqs_need_send_layerwise[req_id]
# update local transferred tokens
send_req_info.update_transferred_tokens(
round_down(send_req_info.local_computed_tokens, self.block_size)
)
# update local computed tokens, not transfer spec decode tokens
spec_decode_tokens = (
len(scheduled_spec_decode_tokens[req_id]) if (req_id in scheduled_spec_decode_tokens) else 0

@@ -714,56 +744,36 @@
def add_tranfer_task(req_id, send_req_info: SendReqInfo, chunk_finish=False):
(
local_block_ids,
remote_block_ids,
remote_cache_tokens,
local_transferred_tokens,
local_transed_tokens,
local_computed_tokens,
request,
) = send_req_info.unpack()
local_trans_block_ids = local_block_ids[
(local_transferred_tokens // self.block_size) : (local_computed_tokens // self.block_size)
]
remote_trans_block_ids = remote_block_ids[
((local_transferred_tokens - remote_cache_tokens) // self.block_size) : (
(local_computed_tokens - remote_cache_tokens) // self.block_size
)
]
request.kv_transfer_params["remote_block_ids"] = remote_trans_block_ids
assert len(local_trans_block_ids) == len(remote_trans_block_ids), (
f"len of local trans block ids : {len(local_trans_block_ids)} not equal to "
f"the len of remote trans block ids : {len(remote_trans_block_ids)}"
)
adjusted_tokens = (
local_computed_tokens - (self.block_size - 1) if chunk_finish else local_computed_tokens
)
logger.info(
f"MooncakeLayerwiseConnector scheduler add transfer task: "
f"{req_id=} {local_block_ids=} {remote_block_ids=} "
f"{local_trans_block_ids=} {remote_trans_block_ids=} "
f"local_computed_tokens={adjusted_tokens} "
f"request.all_token_ids={len(request.all_token_ids)}"
)
meta.add_new_req(
request_id=req_id,
local_block_ids=local_trans_block_ids,
local_block_ids=local_block_ids,
kv_transfer_params=request.kv_transfer_params,
token_ids=[],
chunk_finish=chunk_finish,
remote_cache_tokens=request.kv_transfer_params.get("remote_cached_tokens"),
prompt_len=len(request.all_token_ids),
local_computed_tokens=local_computed_tokens,
local_transed_tokens=local_transed_tokens,
)
logger.debug(
f"MooncakeLayerwiseConnector build_connector_meta: {req_id=}"
f"prompt_len={len(request.all_token_ids)} {local_computed_tokens=}"
f"{local_transed_tokens=}"
f"remote_cache_tokens={request.kv_transfer_params.get('remote_cached_tokens')}"
f"{chunk_finish=} {local_block_ids=}"
f"remote_block_ids={request.kv_transfer_params.get('remote_block_ids')}"
)
# update local_transferred_tokens
local_transferred_tokens = (local_computed_tokens // self.block_size) * self.block_size
send_req_info.update_transferred_tokens(local_transferred_tokens)

# no chunk or last chunk
if send_req_info.local_computed_tokens >= len(send_req_info.request.all_token_ids):
send_req_info.update_computed_tokens(send_req_info.local_computed_tokens + self.block_size - 1)
add_tranfer_task(req_id, send_req_info, chunk_finish=True)
# whether chunk finish
chunk_finish = send_req_info.local_computed_tokens >= len(send_req_info.request.all_token_ids)

add_tranfer_task(req_id, send_req_info, chunk_finish=chunk_finish)
if chunk_finish:
self._reqs_need_send_layerwise.pop(req_id)
# chunk
elif (send_req_info.local_computed_tokens // self.block_size) - (
send_req_info.local_transferred_tokens // self.block_size
) > 0:
add_tranfer_task(req_id, send_req_info)
return meta
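Note: each chunked transfer in `add_tranfer_task` above only sends the blocks between what was already transferred and what has been computed so far. A standalone sketch of that slicing (illustrative only; the real logic lives in the scheduler method above):

```python
# Sketch of the chunk slicing: each chunk covers the blocks between the
# already-transferred token count and the currently-computed token count,
# offset by the remote prefix-cache tokens on the remote side.
def chunk_block_slices(local_block_ids, remote_block_ids, remote_cache_tokens,
                       local_transferred_tokens, local_computed_tokens, block_size):
    local_trans = local_block_ids[
        local_transferred_tokens // block_size : local_computed_tokens // block_size
    ]
    remote_trans = remote_block_ids[
        (local_transferred_tokens - remote_cache_tokens) // block_size :
        (local_computed_tokens - remote_cache_tokens) // block_size
    ]
    assert len(local_trans) == len(remote_trans)
    return local_trans, remote_trans


# Example: block_size=128, 256 tokens already sent, 512 computed -> blocks 2 and 3.
local, remote = chunk_block_slices(list(range(8)), list(range(100, 108)), 0, 256, 512, 128)
assert local == [2, 3] and remote == [102, 103]
```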

def _access_metaserver(self, url, message):

@@ -796,13 +806,7 @@ class MooncakeLayerwiseConnectorWorker:
"""Implementation of Worker side methods"""

def __init__(self, vllm_config: VllmConfig, engine_id: str):
self._get_prefill_decode_size(vllm_config)
os.environ["ASCEND_TRANSFER_TIMEOUT"] = str(get_transfer_timeout_value())
if self._prefill_tp_size < self._decode_tp_size:
raise ValueError(
f"prefill_tp_size: {self._prefill_tp_size} must be greater than"
f" or equal to the decode_tp_size: {self._decode_tp_size}"
)

if TransferEngine is None:
raise RuntimeError("mooncake is not available")

@@ -814,11 +818,20 @@ class MooncakeLayerwiseConnectorWorker:
self.engine_id = engine_id
self.tp_rank = get_tensor_model_parallel_rank()
self.tp_size = vllm_config.parallel_config.tensor_parallel_size
self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size
self.pcp_rank = get_pcp_group().rank_in_group if self.pcp_size > 1 else 0
self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size
self.dcp_rank = get_decode_context_model_parallel_rank() if self.dcp_size > 1 else 0
self.tp_group = get_tp_group()
self._decode_tp_size: int | None = None
self.kv_caches: dict[str, torch.Tensor] = {}
self.side_channel_host = get_ip()
self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
self.use_mla = self.vllm_config.model_config.use_mla
if self.use_mla:
self.total_num_kv_heads = 1
else:
self.total_num_kv_heads = self.vllm_config.model_config.get_total_num_kv_heads()

# Handshake base port
self.side_channel_port = (

@@ -863,23 +876,6 @@ class MooncakeLayerwiseConnectorWorker:
self.k_buffer: torch.Tensor | None = None
self.v_buffer: torch.Tensor | None = None

def _get_prefill_decode_size(self, vllm_config: VllmConfig):
# get prefill tp and dp size from extra config
prefill_parallel_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("prefill", {})

assert "tp_size" in prefill_parallel_config
self._prefill_tp_size = prefill_parallel_config["tp_size"]

assert "dp_size" in prefill_parallel_config
self._prefill_dp_size = prefill_parallel_config["dp_size"]

# get decode tp and dp size from extra config
decode_parallel_config: dict[str, Any] = vllm_config.kv_transfer_config.get_from_extra_config("decode", {})
assert "tp_size" in decode_parallel_config
self._decode_tp_size = decode_parallel_config["tp_size"]
assert "dp_size" in decode_parallel_config
self._decode_dp_size = decode_parallel_config["dp_size"]

def create_kv_buffer(self, first_kv_cache):
if self.pd_head_ratio > 1:
# regesit kv buffer for tp inequal

@@ -977,8 +973,6 @@ class MooncakeLayerwiseConnectorWorker:
kv_cache_base_addr=self.kv_caches_base_addr,
use_mla=self.use_mla,
block_len=self.block_len,
decode_tp_size=self._decode_tp_size,
first_kv_cache=first_kv_cache,
k_buffer=self.k_buffer,
v_buffer=self.v_buffer,
resharding_stream=self.resharding_stream,

@@ -1009,9 +1003,120 @@ class MooncakeLayerwiseConnectorWorker:
else set()
)
if len(done_recving) > 0:
logger.info("Number of completed KV cache recv requests: %d, receive requests: %d", 0, len(done_recving))
logger.info(
f"Number of completed KV cache recv requests: {len(done_recving)}, receive requests: {done_recving}"
)
return set(), done_recving

# {(ip, port)]: {local_block_ids: [], remote_block_ids: {}}}
def _get_kv_split_metadata(self, req_meta, req_idx, req_id):
remote_pcp_size = req_meta.remote_pcp_size
remote_dcp_size = req_meta.remote_dcp_size
remote_tp_size = req_meta.remote_tp_size
remote_hosts = [req_meta.remote_host]
remote_port = req_meta.remote_port
local_transed_tokens = max(req_meta.remote_cache_tokens, req_meta.local_transed_tokens)
# local_transed_tokens tokens that have already been transmitted on the local side
local_computed_tokens = req_meta.local_computed_tokens
prompt_len = req_meta.prompt_len
p_parallel_info = parallel_info(
tp_size=self.tp_size,
pcp_size=self.pcp_size,
dcp_size=self.dcp_size,
pd_head_ratio=self.pd_head_ratio,
use_mla=self.use_mla,
)
d_parallel_info = parallel_info(
tp_size=remote_tp_size,
pcp_size=remote_pcp_size,
dcp_size=remote_dcp_size,
pd_head_ratio=self.pd_head_ratio,
use_mla=self.use_mla,
)
cp_size = self.pcp_size * self.dcp_size
# to_trans_idx all tokens that have been processed up to the current step
if req_meta.chunk_finish:
to_trans_idx = math.ceil(local_computed_tokens / self.block_size)
else:
to_trans_idx = math.floor(local_computed_tokens / self.block_size)
prompt_block_size = math.ceil(prompt_len / self.block_size)
#
num_local_blocks = prompt_block_size // cp_size + int(
(prompt_block_size % cp_size) > (self.pcp_rank * self.dcp_size + self.dcp_rank)
)
already_send_blocks = to_trans_idx // cp_size + int(
(to_trans_idx % cp_size) > (self.pcp_rank * self.dcp_size + self.dcp_rank)
)
if num_local_blocks == already_send_blocks:
req_meta.chunk_finish = True
transed_idx = math.floor(local_transed_tokens / self.block_size)
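Note: the block split above distributes prompt blocks round-robin over the context-parallel world, giving one extra block to ranks whose flattened index falls inside the remainder. A standalone sketch of that count (illustrative only; not the connector code itself):

```python
# Sketch of the per-rank block ownership: with cp_size = pcp_size * dcp_size,
# the rank with flattened index pcp_rank * dcp_size + dcp_rank owns
# total_blocks // cp_size blocks, plus one if the remainder covers its index.
import math


def blocks_owned(total_blocks: int, cp_size: int, pcp_rank: int, dcp_size: int, dcp_rank: int) -> int:
    cp_index = pcp_rank * dcp_size + dcp_rank
    return total_blocks // cp_size + int((total_blocks % cp_size) > cp_index)


# Example: a 10-block prompt over pcp_size=2, dcp_size=2 (cp_size=4) gives
# ranks 0 and 1 three blocks each and ranks 2 and 3 two blocks each.
prompt_blocks = math.ceil(1280 / 128)  # 10
counts = [blocks_owned(prompt_blocks, 4, p, 2, d) for p in range(2) for d in range(2)]
assert counts == [3, 3, 2, 2] and sum(counts) == prompt_blocks
```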

p_cp_group = get_cp_group(self.tp_size, self.total_num_kv_heads, self.dcp_size)
d_cp_group = get_cp_group(remote_tp_size, self.total_num_kv_heads, remote_dcp_size)
logger.debug(f"Compute cp group for P&D {req_id=} {p_cp_group=} {d_cp_group=}")

cp_ratio = len(p_cp_group) // len(d_cp_group)
if cp_ratio == 0:
selected_p_cp_groups = p_cp_group
selected_d_cp_groups = d_cp_group
else:
x = req_idx % cp_ratio
start = x * len(d_cp_group)
selected_p_cp_groups = p_cp_group[start : (start + len(d_cp_group))]
selected_d_cp_groups = d_cp_group
assert len(selected_p_cp_groups) == len(selected_d_cp_groups)

p_head_group_rank = (self.tp_rank - self.dcp_rank) // self.dcp_size
selected_p_cp_group = []
selected_d_cp_group = []
for idx, cp_group in enumerate(selected_p_cp_groups):
if p_head_group_rank in cp_group: # Check whether the rank is in selected_p_cp_groups
selected_p_cp_group = cp_group
selected_d_cp_group = selected_d_cp_groups[idx]
if len(selected_p_cp_group) == 0:
return {}
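Note: when the prefill side has `cp_ratio` times more head groups than the decode side, the request index selects which contiguous slice of prefill groups is paired with the decode groups, so different requests spread over different prefill groups. A sketch of that pairing (hypothetical group lists, illustration only):

```python
# Sketch of the P/D group matching above: req_idx picks which slice of prefill
# head groups is paired one-to-one with the decode head groups.
def match_cp_groups(p_cp_group, d_cp_group, req_idx):
    cp_ratio = len(p_cp_group) // len(d_cp_group)
    if cp_ratio == 0:
        return p_cp_group, d_cp_group
    start = (req_idx % cp_ratio) * len(d_cp_group)
    return p_cp_group[start : start + len(d_cp_group)], d_cp_group


# Hypothetical groups: 4 prefill head groups paired with 2 decode head groups.
p_groups = [[0, 1], [2, 3], [4, 5], [6, 7]]
d_groups = [[0], [1]]
assert match_cp_groups(p_groups, d_groups, req_idx=0)[0] == [[0, 1], [2, 3]]
assert match_cp_groups(p_groups, d_groups, req_idx=1)[0] == [[4, 5], [6, 7]]
```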

logger.debug(
f"MooncakeLayerwiseConnector _get_kv_split_metadata {req_id=} "
f"P-side selected head_group cp group: {selected_p_cp_group}, "
f"D-side selected head_group cp group: {selected_d_cp_group}"
)

context_parallel_parameters_check(
remote_pcp_size, remote_dcp_size, p_parallel_info, d_parallel_info, self.total_num_kv_heads
)
p_rank_block_mapping, d_block_rank_mapping, pd_head_mapping, d_trans_count_mapping = (
get_local_remote_block_port_mappings(
to_trans_idx,
p_parallel_info,
d_parallel_info,
remote_hosts,
remote_port,
selected_p_cp_group,
selected_d_cp_group,
prompt_len,
self.block_size,
req_meta,
self.total_num_kv_heads,
req_id,
)
)
transfer_mappings = get_transfer_mappings(
p_rank_block_mapping,
d_block_rank_mapping,
pd_head_mapping,
d_trans_count_mapping,
req_meta,
p_parallel_info,
req_id,
transed_idx,
to_trans_idx,
self.tp_rank,
self.pcp_rank,
self.dcp_rank,
)
return transfer_mappings

def start_load_kv(self, metadata: MooncakeLayerwiseConnectorMetadata):
"""Start loading KV blocks from remote engine."""
self.current_layer = 0

@@ -1023,31 +1128,29 @@ class MooncakeLayerwiseConnectorWorker:
self.kv_recv_layer_thread.task_tracker[external_req_id] = 0
self.kv_recv_layer_thread.request_map[external_req_id] = req_id
elif self.vllm_config.kv_transfer_config.is_kv_producer:
# select req to send
if self.use_mla or self.use_sparse:
num_need_send = self._decode_tp_size
else:
num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
if self.tp_size <= num_kv_head:
num_need_send = self.tp_size
else:
num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head
num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1
replica_group_idx = self.tp_rank % num_replica_groups
req_ids = sorted(list(metadata.requests.keys()))
selected_req_ids = [
req_id for i, req_id in enumerate(req_ids) if i % num_replica_groups == replica_group_idx
]
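Note: the request selection above partitions pending requests across sender replica groups so that each group transfers a disjoint subset. A standalone sketch of that round-robin split (illustration only, not the worker method):

```python
# Sketch of the sender-side request partitioning: when tp_size exceeds the
# number of ranks that actually need to send, the senders form
# num_replica_groups replicas and requests are split round-robin by sorted id.
def select_requests(req_ids, tp_size, num_need_send, tp_rank):
    num_replica_groups = tp_size // num_need_send if tp_size >= num_need_send else 1
    replica_group_idx = tp_rank % num_replica_groups
    return [r for i, r in enumerate(sorted(req_ids)) if i % num_replica_groups == replica_group_idx]


# Example: tp_size=4, num_need_send=2 -> two replica groups split four requests.
assert select_requests(["r0", "r1", "r2", "r3"], 4, 2, tp_rank=0) == ["r0", "r2"]
assert select_requests(["r0", "r1", "r2", "r3"], 4, 2, tp_rank=1) == ["r1", "r3"]
```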
request_ids = list(metadata.requests.keys())
for req_id in request_ids:
if req_id not in selected_req_ids:
metadata.requests.pop(req_id)
# update trans info
update_metadata = {}
for req_idx, (req_id, req_meta) in enumerate(metadata.requests.items()):
self._decode_tp_size = req_meta.remote_tp_size
transfer_mappings = self._get_kv_split_metadata(req_meta, req_idx, req_id)
assert len(transfer_mappings) <= 1, f"Not support add mutil transfer task for req_id:{req_id}"
update_req_meta = copy.deepcopy(req_meta)
for (host, port), block_dict in transfer_mappings.items():
update_req_meta.remote_host = host
update_req_meta.remote_port = port
update_req_meta.local_block_ids = block_dict["local_block_ids"]
update_req_meta.remote_block_ids = block_dict["remote_block_ids"]
update_req_meta.trans_count = block_dict["trans_count"]
update_metadata[req_id] = update_req_meta
metadata.requests = {}
for req_id, req_meta in update_metadata.items():
metadata.requests[req_id] = update_metadata[req_id]

# update send task trans block info
if self.pd_head_ratio != 1:
send_task = metadata.send_task
send_task.rearrange_block_ids = sorted(
{block_id for req_id in selected_req_ids for block_id in metadata.requests[req_id].local_block_ids}
{block_id for req_id in metadata.requests for block_id in metadata.requests[req_id].local_block_ids}
)

device = self.k_buffer.device # type: ignore

@@ -1070,7 +1173,7 @@
) -> None:
"""MooncakeLayerwiseConnector does not save explicitly."""
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys():
# enable decode prefix cache
# get reshape and cache event
if self.use_mla or self.use_sparse:
reshape_cache_event = attn_metadata[layer_name].reshape_cache_event
else:

@@ -1156,59 +1259,48 @@
return sock

def update_decoder_info(self, req_id, req_meta):
req_meta_update = copy.deepcopy(req_meta)
if self.use_mla or self.use_sparse:
pd_tp_ratio = self.tp_size // self._decode_tp_size
req_meta_update.remote_port = (
req_meta_update.remote_port + (self.tp_rank // pd_tp_ratio) % self._decode_tp_size
)
else:
req_meta_update.remote_port = (
req_meta_update.remote_port + (self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
)
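Note: the side-channel port selection above maps each prefill TP rank onto one of the decode ranks' listener ports by offsetting the base `remote_port`. A standalone sketch (illustrative only; the base port value is hypothetical):

```python
# Sketch of the decode-side port mapping: groups of pd_tp_ratio prefill ranks
# share the same decode rank, and each decode rank listens on base_port + rank.
def decode_side_port(base_port: int, tp_rank: int, pd_tp_ratio: int, decode_tp_size: int) -> int:
    return base_port + (tp_rank // pd_tp_ratio) % decode_tp_size


# Example: prefill tp=8, decode tp=2 (pd_tp_ratio=4), hypothetical base port 5600:
# prefill ranks 0-3 target port 5600 and ranks 4-7 target port 5601.
ports = [decode_side_port(5600, r, 4, 2) for r in range(8)]
assert ports == [5600] * 4 + [5601] * 4
```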
if (
req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr
or req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]
req_meta.remote_engine_id not in self.remote_kv_caches_base_addr
or req_meta.remote_port not in self.remote_kv_caches_base_addr[req_meta.remote_engine_id]
):
try:
encoded_data = self.encoder.encode((GET_META_MSG, req_id))
sock = self._get_remote_socket(req_meta_update.remote_host, req_meta_update.remote_port)
path = f"{req_meta_update.remote_host}:{req_meta_update.remote_port}"
sock = self._get_remote_socket(req_meta.remote_host, req_meta.remote_port)
path = f"{req_meta.remote_host}:{req_meta.remote_port}"
ensure_zmq_send(sock, encoded_data, path)
metadata_bytes = ensure_zmq_recv(sock, self.remote_poller, path)
agent_meta = self.decoder.decode(metadata_bytes)
except Exception as e:
logger.error(
f"Query to port and kv base addr for request {req_id} from "
f"{req_meta_update.remote_host}:{req_meta_update.remote_port} fail with error: {e}"
f"Query to port and kv base addr for request {req_id}"
f"from {req_meta.remote_host}:{req_meta.remote_port}"
f"fail with error: {e}"
)
assert req_meta_update.remote_engine_id != self.engine_id, (
f"Conflict engine id {req_meta_update.remote_engine_id} with local engine id {self.local_engine_id}."
assert req_meta.remote_engine_id != self.engine_id, (
f"Conflict engine id {req_meta.remote_engine_id} with local engine id {self.local_engine_id}."
)
self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][req_meta_update.remote_port] = (
self.remote_kv_caches_base_addr[req_meta.remote_engine_id][req_meta.remote_port] = (
agent_meta.kv_caches_base_addr
)
self.remote_te_port[req_meta_update.remote_engine_id][req_meta_update.remote_port] = agent_meta.te_rpc_port
self.remote_te_port[req_meta.remote_engine_id][req_meta.remote_port] = agent_meta.te_rpc_port
logger.info(
f"Query to port and kv base addr for request {req_id} from "
f"{req_meta_update.remote_host}:{req_meta_update.remote_port} success "
f"{agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
f"Query to port and kv base addr for request {req_id}"
f"from {req_meta.remote_host}:{req_meta.remote_port}"
f"success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
)
if self.pd_head_ratio > 1:
# for tp inequal, pre-create link to prevent alltoall out of memory
session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
session_id = f"{req_meta.remote_host}:{agent_meta.te_rpc_port}"
ret = self.engine.batch_transfer_sync_write(
session_id, [self.kv_caches_base_addr[0]], [agent_meta.kv_caches_base_addr[0]], [128]
)
if ret < 0:
logger.error(f"Mooncake transfer failed to create link to device {session_id}")
req_meta_update.remote_te_rpc_port = self.remote_te_port[req_meta_update.remote_engine_id][
req_meta_update.remote_port
req_meta.remote_te_rpc_port = self.remote_te_port[req_meta.remote_engine_id][req_meta.remote_port]
req_meta.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[req_meta.remote_engine_id][
req_meta.remote_port
]
req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id][
req_meta_update.remote_port
]
return req_meta_update
return req_meta

def send_done_send_signal(self, req_id, req_meta):
external_req_id = get_external_request_id(req_id)

@@ -1221,7 +1313,7 @@ class MooncakeLayerwiseConnectorWorker:
try:
path = make_zmq_path("tcp", req_meta.remote_host, req_meta.remote_port)
msg_encoder = msgspec.msgpack.Encoder()
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id))
encoded_data = msg_encoder.encode((DONE_SENDING_MSG, external_req_id, req_meta.trans_count))
with zmq_ctx(zmq.REQ, path) as sock: # type: ignore
ensure_zmq_send(sock, encoded_data, f"{req_meta.remote_host}:{req_meta.remote_port}")
ack = sock.recv()