[PD] Add support for dp attention with mooncake (#5530)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -109,6 +109,13 @@ class MooncakeKVManager(BaseKVManager):
|
||||
# for p/d multi node infer
|
||||
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
||||
self.dist_init_addr = server_args.dist_init_addr
|
||||
self.tp_size = server_args.tp_size
|
||||
self.dp_size = server_args.dp_size
|
||||
self.enable_dp_attention = server_args.enable_dp_attention
|
||||
if not server_args.enable_dp_attention and server_args.dp_size != 1:
|
||||
raise ValueError(
|
||||
"If dp_attention is not enabled, dp size must be 1 in disaggregation mode."
|
||||
)
|
||||
self.request_status: Dict[int, KVPoll] = {}
|
||||
self.rank_port = None
|
||||
self.server_socket = zmq.Context().socket(zmq.PULL)
|
||||
@@ -121,6 +128,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
self.start_decode_thread()
|
||||
self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
|
||||
self.prefill_dp_size_table: Dict[str, int] = {}
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
||||
@@ -331,6 +339,8 @@ class MooncakeKVManager(BaseKVManager):
|
||||
url = f"http://{bootstrap_server_url}/route"
|
||||
payload = {
|
||||
"role": "Prefill",
|
||||
"tp_size": self.tp_size,
|
||||
"dp_size": self.dp_size,
|
||||
"rank_ip": get_local_ip_by_remote(),
|
||||
"rank_port": self.rank_port,
|
||||
"engine_rank": self.kv_args.engine_rank,
|
||||
@@ -408,12 +418,41 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
self.session_id = self.kv_mgr.get_session_id()
|
||||
self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
|
||||
|
||||
if not self.kv_mgr.enable_dp_attention:
|
||||
# We assume dp_attention should be activated simultaneously for
|
||||
# both prefill role and decode role. If the decode instance does
|
||||
# not enable dp_attention, then dp_attention is not enabled on the
|
||||
# prefill instance as well. Therefore, we should skip questioning
|
||||
# the prefill dp size to reduce bootstrap overhead.
|
||||
self.prefill_dp_size = 1
|
||||
elif self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
||||
self.prefill_dp_size, tp_size_per_dp_rank = (
|
||||
self._get_prefill_dp_size_from_server()
|
||||
)
|
||||
# Currently, we don't allow prefill instance and decode instance to
|
||||
# have different TP sizes per DP rank.
|
||||
assert tp_size_per_dp_rank == self.kv_mgr.tp_size // self.kv_mgr.dp_size
|
||||
if self.prefill_dp_size is None:
|
||||
logger.error(
|
||||
f"Could not fetch prefill dp_size for bootstrap_addr: {self.bootstrap_addr}"
|
||||
)
|
||||
else:
|
||||
self.kv_mgr.prefill_dp_size_table[self.bootstrap_addr] = (
|
||||
self.prefill_dp_size
|
||||
)
|
||||
else:
|
||||
self.prefill_dp_size = self.kv_mgr.prefill_dp_size_table[
|
||||
self.bootstrap_addr
|
||||
]
|
||||
|
||||
# NOTE: key distinguished by bootstrap_addr and engine_rank
|
||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||
bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"
|
||||
|
||||
if bootstrap_key not in self.kv_mgr.connection_pool:
|
||||
self.bootstrap_info = self._get_bootstrap_info_from_server(
|
||||
self.kv_mgr.kv_args.engine_rank
|
||||
self.kv_mgr.kv_args.engine_rank,
|
||||
self.target_dp_group,
|
||||
)
|
||||
if self.bootstrap_info is None:
|
||||
logger.error(
|
||||
@@ -427,10 +466,10 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
assert self.bootstrap_info is not None
|
||||
self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)
|
||||
|
||||
def _get_bootstrap_info_from_server(self, engine_rank):
|
||||
def _get_bootstrap_info_from_server(self, engine_rank, target_dp_group):
|
||||
"""Fetch the bootstrap info from the bootstrap server."""
|
||||
try:
|
||||
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
|
||||
url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}&target_dp_group={target_dp_group}"
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
bootstrap_info = response.json()
|
||||
@@ -444,6 +483,25 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
logger.error(f"Error fetching prefill info from bootstrap: {e}")
|
||||
return None
|
||||
|
||||
def _get_prefill_dp_size_from_server(self) -> int:
|
||||
"""Fetch the prefill parallel info from the bootstrap server."""
|
||||
try:
|
||||
url = f"http://{self.bootstrap_addr}/route?engine_rank={-1}&target_dp_group={-1}"
|
||||
response = requests.get(url)
|
||||
if response.status_code == 200:
|
||||
prefill_parallel_info = response.json()
|
||||
return int(prefill_parallel_info["prefill_dp_size"]), int(
|
||||
prefill_parallel_info["tp_size_per_dp_rank"]
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to get prefill parallel info: {response.status_code}, {response.text}"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching prefill parallel info from bootstrap: {e}")
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _connect(cls, endpoint: str):
|
||||
with cls._global_lock:
|
||||
@@ -497,7 +555,9 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
||||
self.store = dict()
|
||||
self.lock = asyncio.Lock()
|
||||
self._setup_routes()
|
||||
self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
|
||||
self.dp_size = None
|
||||
self.tp_size_per_dp_rank = None
|
||||
self.prefill_port_table: Dict[int, Dict[int, Dict[str, Union[str, int]]]] = {}
|
||||
|
||||
# Start bootstrap server
|
||||
self.thread = threading.Thread(target=self._run_server, daemon=True)
|
||||
@@ -523,35 +583,64 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
||||
async def _handle_route_put(self, request: web.Request):
|
||||
data = await request.json()
|
||||
role = data["role"]
|
||||
tp_size = data["tp_size"]
|
||||
dp_size = data["dp_size"]
|
||||
rank_ip = data["rank_ip"]
|
||||
rank_port = int(data["rank_port"])
|
||||
engine_rank = int(data["engine_rank"])
|
||||
|
||||
if self.dp_size is None:
|
||||
self.dp_size = dp_size
|
||||
|
||||
tp_size_per_dp_rank = tp_size // dp_size
|
||||
if self.tp_size_per_dp_rank == None:
|
||||
self.tp_size_per_dp_rank = tp_size_per_dp_rank
|
||||
|
||||
# Add lock to make sure thread-safe
|
||||
if role == "Prefill":
|
||||
self.prefill_port_table[engine_rank] = {
|
||||
dp_group = engine_rank // tp_size_per_dp_rank
|
||||
tp_rank_in_dp_group = engine_rank % tp_size_per_dp_rank
|
||||
|
||||
async with self.lock:
|
||||
if dp_group not in self.prefill_port_table:
|
||||
self.prefill_port_table[dp_group] = {}
|
||||
|
||||
self.prefill_port_table[dp_group][tp_rank_in_dp_group] = {
|
||||
"rank_ip": rank_ip,
|
||||
"rank_port": rank_port,
|
||||
}
|
||||
logger.debug(
|
||||
f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
||||
f"Registered Prefill bootstrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
||||
)
|
||||
|
||||
return web.Response(text="OK", status=200)
|
||||
|
||||
async def _handle_route_get(self, request: web.Request):
|
||||
engine_rank = request.query.get("engine_rank")
|
||||
if not engine_rank:
|
||||
return web.Response(text="Missing rank", status=400)
|
||||
target_dp_group = request.query.get("target_dp_group")
|
||||
if not engine_rank or not target_dp_group:
|
||||
return web.Response(text="Missing inputs for bootstrap server.", status=400)
|
||||
|
||||
# Currently we use engine_rank == -1 and target_dp_group == -1 to sync dp size
|
||||
if int(engine_rank) == -1 and int(target_dp_group) == -1:
|
||||
prefill_parallel_info = {
|
||||
"prefill_dp_size": self.dp_size,
|
||||
"tp_size_per_dp_rank": self.tp_size_per_dp_rank,
|
||||
}
|
||||
return web.json_response(prefill_parallel_info, status=200)
|
||||
|
||||
# Find corresponding prefill info
|
||||
tp_rank_in_dp_group = int(engine_rank) % self.tp_size_per_dp_rank
|
||||
|
||||
async with self.lock:
|
||||
bootstrap_info = self.prefill_port_table.get(int(engine_rank))
|
||||
bootstrap_info = self.prefill_port_table[int(target_dp_group)][
|
||||
tp_rank_in_dp_group
|
||||
]
|
||||
|
||||
if bootstrap_info is not None:
|
||||
return web.json_response(bootstrap_info, status=200)
|
||||
else:
|
||||
return web.Response(text="Not Found", status=404)
|
||||
return web.Response(text="Bootstrap info not Found", status=404)
|
||||
|
||||
def _run_server(self):
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user