[PD] Raise error for incompatible mooncake version and some minor fixes (#7527)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -187,7 +187,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
).start()
|
||||
|
||||
self.bootstrap_time_out = get_int_env_var(
|
||||
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30
|
||||
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 120
|
||||
)
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
self.heartbeat_failures = {}
|
||||
@@ -195,8 +195,8 @@ class MooncakeKVManager(BaseKVManager):
|
||||
self.session_pool_lock = threading.Lock()
|
||||
self.addr_to_rooms_tracker = defaultdict(set)
|
||||
self.connection_lock = threading.Lock()
|
||||
self.required_prefill_info_num_map: Dict[int, int] = {}
|
||||
self.decode_kv_arrive_state: Dict[int, Set[int]] = defaultdict(set)
|
||||
self.required_prefill_response_num_table: Dict[int, int] = {}
|
||||
self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
|
||||
# Heartbeat interval should be at least 2 seconds
|
||||
self.heartbeat_interval = max(
|
||||
float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
|
||||
@@ -311,22 +311,23 @@ class MooncakeKVManager(BaseKVManager):
|
||||
each page to ensure correctness for any page_size and head-slicing configuration.
|
||||
This may introduce performance overhead (increased TTFT) for long sequences.
|
||||
"""
|
||||
# rank/kv_head config
|
||||
# Extract configuration
|
||||
local_tp_rank = self.kv_args.engine_rank
|
||||
local_tp_size = self.tp_size // self.dp_size
|
||||
num_kv_heads = self.kv_args.kv_head_num
|
||||
num_layers = len(self.kv_args.kv_data_ptrs)
|
||||
page_size = self.kv_args.page_size
|
||||
|
||||
# Calculate head distribution
|
||||
heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
|
||||
heads_per_prefill_rank = num_kv_heads
|
||||
decode_global_head_start = dst_tp_rank * heads_per_decode_rank
|
||||
prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
|
||||
bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
|
||||
|
||||
# decode config
|
||||
decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
|
||||
|
||||
# Determine slicing parameters based on TP configuration
|
||||
if local_tp_size > dst_tp_size:
|
||||
src_head_offset = 0
|
||||
num_heads_to_send = heads_per_prefill_rank
|
||||
@@ -340,7 +341,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
for layer_id in range(num_layers):
|
||||
item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
|
||||
|
||||
# Page stride on the target Decode rank for its slice pages
|
||||
# Page stride on the target dst decode rank for its slice pages
|
||||
item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
|
||||
|
||||
if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
|
||||
@@ -349,12 +350,12 @@ class MooncakeKVManager(BaseKVManager):
|
||||
)
|
||||
return -1
|
||||
|
||||
# Calculate precise byte offset and length for the sub-slice within Prefill page data
|
||||
# Calculate precise byte offset and length for the sub-slice within the prefill page data
|
||||
src_slice_offset = src_head_offset * bytes_per_head
|
||||
dst_slice_offset = dst_head_offset * bytes_per_head
|
||||
slice_lens_per_page = num_heads_to_send * bytes_per_head
|
||||
|
||||
# Sanity check: The data sub-slice we intend to send should fit into D_n's page.
|
||||
# Sanity check: The data sub-slice to be sent should fit into the decode instance's page.
|
||||
# This means slice_lens_per_page <= item_len_of_decode_rank_page
|
||||
if slice_lens_per_page > item_len_of_decode_rank_page:
|
||||
logger.error(
|
||||
@@ -365,15 +366,13 @@ class MooncakeKVManager(BaseKVManager):
|
||||
return -1
|
||||
layer_transfer_params.append(
|
||||
(
|
||||
self.kv_args.kv_data_ptrs[layer_id], # Prefill base ptr (all heads)
|
||||
dst_kv_ptrs[
|
||||
layer_id
|
||||
], # Decode base ptr (for its slice for this layer)
|
||||
item_len_of_prefill_rank_page, # Prefill page size (all heads)2048
|
||||
item_len_of_decode_rank_page, # Decode page stride (for its slice page) 1024
|
||||
src_slice_offset, # Offset to slice data in Prefill page
|
||||
dst_slice_offset, # Offset to slice data in Decode page
|
||||
slice_lens_per_page, # Length of slice data per page (actual data to send)
|
||||
self.kv_args.kv_data_ptrs[layer_id],
|
||||
dst_kv_ptrs[layer_id],
|
||||
item_len_of_prefill_rank_page,
|
||||
item_len_of_decode_rank_page,
|
||||
src_slice_offset,
|
||||
dst_slice_offset,
|
||||
slice_lens_per_page,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -399,13 +398,13 @@ class MooncakeKVManager(BaseKVManager):
|
||||
prefill_page_idx = int(prefill_kv_indices[i])
|
||||
decode_page_idx = int(dst_kv_indices[i])
|
||||
|
||||
# Get the starting memory address for the current source and destination pages
|
||||
# Get the starting addresses for the current src and dst pages
|
||||
src_page_start_addr = src_ptr + prefill_page_idx * src_item_len
|
||||
dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len
|
||||
|
||||
# Iterate through each valid token slot within the current page
|
||||
for token_slot_in_page in range(page_size):
|
||||
# Calculate start address of the current token slot
|
||||
# Calculate the start address of the current token slot
|
||||
src_token_slot_start_addr = (
|
||||
src_page_start_addr
|
||||
+ token_slot_in_page * bytes_per_token_on_prefill
|
||||
@@ -415,7 +414,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
+ token_slot_in_page * bytes_per_token_on_decode
|
||||
)
|
||||
|
||||
# Calculate final source and destination addresses by applying head-slice offsets
|
||||
# Calculate final src and dst addresses by applying head-slice offsets
|
||||
src_slice_addr = src_token_slot_start_addr + src_offset
|
||||
dst_slice_addr = dst_token_slot_start_addr + dst_offset
|
||||
|
||||
@@ -585,9 +584,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
ret = self.send_aux(
|
||||
req.mooncake_session_id,
|
||||
kv_chunk.prefill_aux_index,
|
||||
self.decode_kv_args_table[
|
||||
req.mooncake_session_id
|
||||
].dst_aux_ptrs,
|
||||
target_rank_registration_info.dst_aux_ptrs,
|
||||
req.dst_aux_index,
|
||||
)
|
||||
polls.append(True if ret == 0 else False)
|
||||
@@ -675,19 +672,19 @@ class MooncakeKVManager(BaseKVManager):
|
||||
prefill_rank = int(prefill_rank.decode("ascii"))
|
||||
|
||||
if status == KVPoll.Success:
|
||||
# record arrived prefill_rank
|
||||
self.decode_kv_arrive_state[bootstrap_room].add(prefill_rank)
|
||||
expected_prefill_num = self.required_prefill_info_num_map[
|
||||
bootstrap_room
|
||||
]
|
||||
arrived_prefill_num = len(
|
||||
self.decode_kv_arrive_state[bootstrap_room]
|
||||
)
|
||||
if (
|
||||
self.is_mla_backend
|
||||
or arrived_prefill_num == expected_prefill_num
|
||||
):
|
||||
self.update_status(bootstrap_room, KVPoll.Success)
|
||||
if bootstrap_room in self.request_status:
|
||||
self.prefill_response_tracker[bootstrap_room].add(prefill_rank)
|
||||
expected_response_num = (
|
||||
self.required_prefill_response_num_table[bootstrap_room]
|
||||
)
|
||||
arrived_response_num = len(
|
||||
self.prefill_response_tracker[bootstrap_room]
|
||||
)
|
||||
if (
|
||||
self.is_mla_backend
|
||||
or arrived_response_num == expected_response_num
|
||||
):
|
||||
self.update_status(bootstrap_room, KVPoll.Success)
|
||||
elif status == KVPoll.Failed:
|
||||
self.record_failure(
|
||||
bootstrap_room,
|
||||
@@ -900,14 +897,13 @@ class MooncakeKVSender(BaseKVSender):
|
||||
self.aux_index = None
|
||||
self.bootstrap_server_url = bootstrap_addr
|
||||
self.conclude_state = None
|
||||
self.init_time = None
|
||||
self.init_time = time.time()
|
||||
# inner state
|
||||
self.curr_idx = 0
|
||||
|
||||
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
|
||||
self.num_kv_indices = num_kv_indices
|
||||
self.aux_index = aux_index
|
||||
self.init_time = time.time()
|
||||
|
||||
def send(
|
||||
self,
|
||||
@@ -1031,7 +1027,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
|
||||
)
|
||||
self.required_dst_info_num = 1
|
||||
self.required_prefill_info_num = 1
|
||||
self.required_prefill_response_num = 1
|
||||
self.target_tp_ranks = [self.target_tp_rank]
|
||||
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
|
||||
if not self.kv_mgr.is_mla_backend:
|
||||
@@ -1044,7 +1040,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
self.required_dst_info_num = (
|
||||
local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
|
||||
)
|
||||
self.required_prefill_info_num = 1
|
||||
self.required_prefill_response_num = 1
|
||||
self.target_tp_ranks = [self.target_tp_rank]
|
||||
else:
|
||||
if not self.kv_mgr.is_mla_backend:
|
||||
@@ -1067,7 +1063,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
# or the KVPoll will never be set correctly
|
||||
self.target_tp_rank = self.target_tp_ranks[0]
|
||||
self.required_dst_info_num = 1
|
||||
self.required_prefill_info_num = (
|
||||
self.required_prefill_response_num = (
|
||||
prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
|
||||
)
|
||||
|
||||
@@ -1077,8 +1073,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
else:
|
||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||
|
||||
self.kv_mgr.required_prefill_info_num_map[self.bootstrap_room] = (
|
||||
self.required_prefill_info_num
|
||||
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
|
||||
self.required_prefill_response_num
|
||||
)
|
||||
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
||||
bootstrap_key = (
|
||||
@@ -1094,13 +1090,13 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
)
|
||||
if bootstrap_info is not None:
|
||||
if self.kv_mgr.is_mla_backend:
|
||||
# MLA :select one prefill rank as real rank
|
||||
# For MLA: target_tp_rank is the selected real rank, others are dummy ranks
|
||||
bootstrap_info["is_dummy"] = not bool(
|
||||
target_tp_rank == self.target_tp_rank
|
||||
or self.target_tp_rank is None
|
||||
)
|
||||
else:
|
||||
# no-MLA:select all prefill ranks
|
||||
# For non-MLA: all target_tp_ranks are selected real ranks
|
||||
bootstrap_info["is_dummy"] = False
|
||||
logger.debug(
|
||||
f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
|
||||
@@ -1240,8 +1236,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
def clear(self) -> None:
|
||||
if self.bootstrap_room in self.kv_mgr.request_status:
|
||||
self.kv_mgr.request_status.pop(self.bootstrap_room)
|
||||
self.kv_mgr.required_prefill_info_num_map.pop(self.bootstrap_room)
|
||||
self.kv_mgr.decode_kv_arrive_state.pop(self.bootstrap_room)
|
||||
|
||||
if self.bootstrap_room in self.kv_mgr.required_prefill_response_num_table:
|
||||
self.kv_mgr.required_prefill_response_num_table.pop(self.bootstrap_room)
|
||||
|
||||
if self.bootstrap_room in self.kv_mgr.prefill_response_tracker:
|
||||
self.kv_mgr.prefill_response_tracker.pop(self.bootstrap_room)
|
||||
|
||||
def failure_exception(self):
|
||||
# Explicitly set the status to failure since this request has failed in another rank
|
||||
|
||||
@@ -97,13 +97,19 @@ class MooncakeTransferEngine:
|
||||
peer_buffer_addresses: List[int],
|
||||
lengths: List[int],
|
||||
) -> int:
|
||||
"""Synchronously transfer data to the specified address."""
|
||||
"""Synchronously transfer data to the specified addresses in batches."""
|
||||
try:
|
||||
ret = self.engine.batch_transfer_sync_write(
|
||||
session_id, buffers, peer_buffer_addresses, lengths
|
||||
)
|
||||
except Exception:
|
||||
ret = -1
|
||||
# Inform user to upgrade mooncake-transfer-engine >= 0.3.4.post2
|
||||
if not hasattr(self.engine, "batch_transfer_sync_write"):
|
||||
raise RuntimeError(
|
||||
"Mooncake's batch transfer requires mooncake-transfer-engine >= 0.3.4.post2. "
|
||||
"Please upgrade Mooncake by 'pip install mooncake-transfer-engine --upgrade'"
|
||||
)
|
||||
|
||||
if ret < 0:
|
||||
logger.debug(
|
||||
|
||||
Reference in New Issue
Block a user