[PD] Raise error for incompatible mooncake version and some minor fixes (#7527)
Signed-off-by: Shangming Cai <caishangming@linux.alibaba.com>
This commit is contained in:
@@ -56,7 +56,7 @@ PD Disaggregation with Mooncake supports the following environment variables for
|
||||
|:--------:|:-----------:|:--------:
|
||||
| **`SGLANG_DISAGGREGATION_THREAD_POOL_SIZE`** | Controls the total number of worker threads for KVCache transfer operations per TP rank | A dynamic value calculated by `int(0.75 * os.cpu_count()) // 8)`, which is limited to be larger than 4 and less than 12 to ensure efficiency and prevent thread race conditions |
|
||||
| **`SGLANG_DISAGGREGATION_QUEUE_SIZE`** | Sets the number of parallel transfer queues. KVCache transfer requests from multiple decode instances will be sharded into these queues so that they can share the threads and the transfer bandwidth at the same time. If it is set to `1`, then we transfer requests one by one according to fcfs strategy | `4` |
|
||||
| **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `30` |
|
||||
| **`SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT`** | Timeout (seconds) for receiving destination KV indices during request initialization | `120` |
|
||||
|
||||
#### Decode Server Configuration
|
||||
| Variable | Description | Default |
|
||||
|
||||
@@ -187,7 +187,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
).start()
|
||||
|
||||
self.bootstrap_time_out = get_int_env_var(
|
||||
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 30
|
||||
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 120
|
||||
)
|
||||
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
||||
self.heartbeat_failures = {}
|
||||
@@ -195,8 +195,8 @@ class MooncakeKVManager(BaseKVManager):
|
||||
self.session_pool_lock = threading.Lock()
|
||||
self.addr_to_rooms_tracker = defaultdict(set)
|
||||
self.connection_lock = threading.Lock()
|
||||
self.required_prefill_info_num_map: Dict[int, int] = {}
|
||||
self.decode_kv_arrive_state: Dict[int, Set[int]] = defaultdict(set)
|
||||
self.required_prefill_response_num_table: Dict[int, int] = {}
|
||||
self.prefill_response_tracker: Dict[int, Set[int]] = defaultdict(set)
|
||||
# Heartbeat interval should be at least 2 seconds
|
||||
self.heartbeat_interval = max(
|
||||
float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
|
||||
@@ -311,22 +311,23 @@ class MooncakeKVManager(BaseKVManager):
|
||||
each page to ensure correctness for any page_size and head-slicing configuration.
|
||||
This may introduce performance overhead (increased TTFT) for long sequences.
|
||||
"""
|
||||
# rank/kv_head config
|
||||
# Extract configuration
|
||||
local_tp_rank = self.kv_args.engine_rank
|
||||
local_tp_size = self.tp_size // self.dp_size
|
||||
num_kv_heads = self.kv_args.kv_head_num
|
||||
num_layers = len(self.kv_args.kv_data_ptrs)
|
||||
page_size = self.kv_args.page_size
|
||||
|
||||
# Calculate head distribution
|
||||
heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
|
||||
heads_per_prefill_rank = num_kv_heads
|
||||
decode_global_head_start = dst_tp_rank * heads_per_decode_rank
|
||||
prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
|
||||
bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
|
||||
|
||||
# decode config
|
||||
decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
|
||||
|
||||
# Determine slicing parameters based on TP configuration
|
||||
if local_tp_size > dst_tp_size:
|
||||
src_head_offset = 0
|
||||
num_heads_to_send = heads_per_prefill_rank
|
||||
@@ -340,7 +341,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
for layer_id in range(num_layers):
|
||||
item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
|
||||
|
||||
# Page stride on the target Decode rank for its slice pages
|
||||
# Page stride on the target dst decode rank for its slice pages
|
||||
item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
|
||||
|
||||
if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
|
||||
@@ -349,12 +350,12 @@ class MooncakeKVManager(BaseKVManager):
|
||||
)
|
||||
return -1
|
||||
|
||||
# Calculate precise byte offset and length for the sub-slice within Prefill page data
|
||||
# Calculate precise byte offset and length for the sub-slice within the prefill page data
|
||||
src_slice_offset = src_head_offset * bytes_per_head
|
||||
dst_slice_offset = dst_head_offset * bytes_per_head
|
||||
slice_lens_per_page = num_heads_to_send * bytes_per_head
|
||||
|
||||
# Sanity check: The data sub-slice we intend to send should fit into D_n's page.
|
||||
# Sanity check: The data sub-slice to be sent should fit into the decode instance's page.
|
||||
# This means slice_lens_per_page <= item_len_of_decode_rank_page
|
||||
if slice_lens_per_page > item_len_of_decode_rank_page:
|
||||
logger.error(
|
||||
@@ -365,15 +366,13 @@ class MooncakeKVManager(BaseKVManager):
|
||||
return -1
|
||||
layer_transfer_params.append(
|
||||
(
|
||||
self.kv_args.kv_data_ptrs[layer_id], # Prefill base ptr (all heads)
|
||||
dst_kv_ptrs[
|
||||
layer_id
|
||||
], # Decode base ptr (for its slice for this layer)
|
||||
item_len_of_prefill_rank_page, # Prefill page size (all heads)2048
|
||||
item_len_of_decode_rank_page, # Decode page stride (for its slice page) 1024
|
||||
src_slice_offset, # Offset to slice data in Prefill page
|
||||
dst_slice_offset, # Offset to slice data in Decode page
|
||||
slice_lens_per_page, # Length of slice data per page (actual data to send)
|
||||
self.kv_args.kv_data_ptrs[layer_id],
|
||||
dst_kv_ptrs[layer_id],
|
||||
item_len_of_prefill_rank_page,
|
||||
item_len_of_decode_rank_page,
|
||||
src_slice_offset,
|
||||
dst_slice_offset,
|
||||
slice_lens_per_page,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -399,13 +398,13 @@ class MooncakeKVManager(BaseKVManager):
|
||||
prefill_page_idx = int(prefill_kv_indices[i])
|
||||
decode_page_idx = int(dst_kv_indices[i])
|
||||
|
||||
# Get the starting memory address for the current source and destination pages
|
||||
# Get the starting addresses for the current src and dst pages
|
||||
src_page_start_addr = src_ptr + prefill_page_idx * src_item_len
|
||||
dst_page_start_addr = dst_ptr + decode_page_idx * dst_item_len
|
||||
|
||||
# Iterate through each valid token slot within the current page
|
||||
for token_slot_in_page in range(page_size):
|
||||
# Calculate start address of the current token slot
|
||||
# Calculate the start address of the current token slot
|
||||
src_token_slot_start_addr = (
|
||||
src_page_start_addr
|
||||
+ token_slot_in_page * bytes_per_token_on_prefill
|
||||
@@ -415,7 +414,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
+ token_slot_in_page * bytes_per_token_on_decode
|
||||
)
|
||||
|
||||
# Calculate final source and destination addresses by applying head-slice offsets
|
||||
# Calculate final src and dst addresses by applying head-slice offsets
|
||||
src_slice_addr = src_token_slot_start_addr + src_offset
|
||||
dst_slice_addr = dst_token_slot_start_addr + dst_offset
|
||||
|
||||
@@ -585,9 +584,7 @@ class MooncakeKVManager(BaseKVManager):
|
||||
ret = self.send_aux(
|
||||
req.mooncake_session_id,
|
||||
kv_chunk.prefill_aux_index,
|
||||
self.decode_kv_args_table[
|
||||
req.mooncake_session_id
|
||||
].dst_aux_ptrs,
|
||||
target_rank_registration_info.dst_aux_ptrs,
|
||||
req.dst_aux_index,
|
||||
)
|
||||
polls.append(True if ret == 0 else False)
|
||||
@@ -675,19 +672,19 @@ class MooncakeKVManager(BaseKVManager):
|
||||
prefill_rank = int(prefill_rank.decode("ascii"))
|
||||
|
||||
if status == KVPoll.Success:
|
||||
# record arrived prefill_rank
|
||||
self.decode_kv_arrive_state[bootstrap_room].add(prefill_rank)
|
||||
expected_prefill_num = self.required_prefill_info_num_map[
|
||||
bootstrap_room
|
||||
]
|
||||
arrived_prefill_num = len(
|
||||
self.decode_kv_arrive_state[bootstrap_room]
|
||||
)
|
||||
if (
|
||||
self.is_mla_backend
|
||||
or arrived_prefill_num == expected_prefill_num
|
||||
):
|
||||
self.update_status(bootstrap_room, KVPoll.Success)
|
||||
if bootstrap_room in self.request_status:
|
||||
self.prefill_response_tracker[bootstrap_room].add(prefill_rank)
|
||||
expected_response_num = (
|
||||
self.required_prefill_response_num_table[bootstrap_room]
|
||||
)
|
||||
arrived_response_num = len(
|
||||
self.prefill_response_tracker[bootstrap_room]
|
||||
)
|
||||
if (
|
||||
self.is_mla_backend
|
||||
or arrived_response_num == expected_response_num
|
||||
):
|
||||
self.update_status(bootstrap_room, KVPoll.Success)
|
||||
elif status == KVPoll.Failed:
|
||||
self.record_failure(
|
||||
bootstrap_room,
|
||||
@@ -900,14 +897,13 @@ class MooncakeKVSender(BaseKVSender):
|
||||
self.aux_index = None
|
||||
self.bootstrap_server_url = bootstrap_addr
|
||||
self.conclude_state = None
|
||||
self.init_time = None
|
||||
self.init_time = time.time()
|
||||
# inner state
|
||||
self.curr_idx = 0
|
||||
|
||||
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
|
||||
self.num_kv_indices = num_kv_indices
|
||||
self.aux_index = aux_index
|
||||
self.init_time = time.time()
|
||||
|
||||
def send(
|
||||
self,
|
||||
@@ -1031,7 +1027,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
|
||||
)
|
||||
self.required_dst_info_num = 1
|
||||
self.required_prefill_info_num = 1
|
||||
self.required_prefill_response_num = 1
|
||||
self.target_tp_ranks = [self.target_tp_rank]
|
||||
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
|
||||
if not self.kv_mgr.is_mla_backend:
|
||||
@@ -1044,7 +1040,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
self.required_dst_info_num = (
|
||||
local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank
|
||||
)
|
||||
self.required_prefill_info_num = 1
|
||||
self.required_prefill_response_num = 1
|
||||
self.target_tp_ranks = [self.target_tp_rank]
|
||||
else:
|
||||
if not self.kv_mgr.is_mla_backend:
|
||||
@@ -1067,7 +1063,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
# or the KVPoll will never be set correctly
|
||||
self.target_tp_rank = self.target_tp_ranks[0]
|
||||
self.required_dst_info_num = 1
|
||||
self.required_prefill_info_num = (
|
||||
self.required_prefill_response_num = (
|
||||
prefill_tp_size_per_dp_rank // local_tp_size_per_dp_rank
|
||||
)
|
||||
|
||||
@@ -1077,8 +1073,8 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
else:
|
||||
self.target_dp_group = bootstrap_room % self.prefill_dp_size
|
||||
|
||||
self.kv_mgr.required_prefill_info_num_map[self.bootstrap_room] = (
|
||||
self.required_prefill_info_num
|
||||
self.kv_mgr.required_prefill_response_num_table[self.bootstrap_room] = (
|
||||
self.required_prefill_response_num
|
||||
)
|
||||
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
||||
bootstrap_key = (
|
||||
@@ -1094,13 +1090,13 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
)
|
||||
if bootstrap_info is not None:
|
||||
if self.kv_mgr.is_mla_backend:
|
||||
# MLA :select one prefill rank as real rank
|
||||
# For MLA: target_tp_rank is the selected real rank, others are dummy ranks
|
||||
bootstrap_info["is_dummy"] = not bool(
|
||||
target_tp_rank == self.target_tp_rank
|
||||
or self.target_tp_rank is None
|
||||
)
|
||||
else:
|
||||
# no-MLA:select all prefill ranks
|
||||
# For non-MLA: all target_tp_ranks are selected real ranks
|
||||
bootstrap_info["is_dummy"] = False
|
||||
logger.debug(
|
||||
f"Fetched bootstrap info: {bootstrap_info} for DP {self.target_dp_group} TP {target_tp_rank}"
|
||||
@@ -1240,8 +1236,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
||||
def clear(self) -> None:
|
||||
if self.bootstrap_room in self.kv_mgr.request_status:
|
||||
self.kv_mgr.request_status.pop(self.bootstrap_room)
|
||||
self.kv_mgr.required_prefill_info_num_map.pop(self.bootstrap_room)
|
||||
self.kv_mgr.decode_kv_arrive_state.pop(self.bootstrap_room)
|
||||
|
||||
if self.bootstrap_room in self.kv_mgr.required_prefill_response_num_table:
|
||||
self.kv_mgr.required_prefill_response_num_table.pop(self.bootstrap_room)
|
||||
|
||||
if self.bootstrap_room in self.kv_mgr.prefill_response_tracker:
|
||||
self.kv_mgr.prefill_response_tracker.pop(self.bootstrap_room)
|
||||
|
||||
def failure_exception(self):
|
||||
# Explicitly set the status to failure since this request has failed in another rank
|
||||
|
||||
@@ -97,13 +97,19 @@ class MooncakeTransferEngine:
|
||||
peer_buffer_addresses: List[int],
|
||||
lengths: List[int],
|
||||
) -> int:
|
||||
"""Synchronously transfer data to the specified address."""
|
||||
"""Synchronously transfer data to the specified addresses in batches."""
|
||||
try:
|
||||
ret = self.engine.batch_transfer_sync_write(
|
||||
session_id, buffers, peer_buffer_addresses, lengths
|
||||
)
|
||||
except Exception:
|
||||
ret = -1
|
||||
# Inform user to upgrade mooncake-transfer-engine >= 0.3.4.post2
|
||||
if not hasattr(self.engine, "batch_transfer_sync_write"):
|
||||
raise RuntimeError(
|
||||
"Mooncake's batch transfer requires mooncake-transfer-engine >= 0.3.4.post2. "
|
||||
"Please upgrade Mooncake by 'pip install mooncake-transfer-engine --upgrade'"
|
||||
)
|
||||
|
||||
if ret < 0:
|
||||
logger.debug(
|
||||
|
||||
Reference in New Issue
Block a user