add variable TP Decode > Prefill size support (#9960)
Signed-off-by: Shahar Mor <smor@nvidia.com>
This commit is contained in:
@@ -168,9 +168,6 @@ class CommonKVReceiver(BaseKVReceiver):
|
||||
self.required_dst_info_num = 1
|
||||
self.target_tp_ranks = [self.target_tp_rank]
|
||||
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
|
||||
assert (
|
||||
self.kv_mgr.is_mla_backend
|
||||
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
|
||||
self.target_tp_rank = (
|
||||
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
|
||||
) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
|
||||
|
||||
@@ -459,7 +459,9 @@ class MooncakeKVManager(BaseKVManager):
|
||||
dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
|
||||
else:
|
||||
# Send KVCache from 1 prefill instance to multiple decode instances
|
||||
src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
|
||||
src_head_start_offset = (
|
||||
dst_tp_rank_in_group * dst_heads_per_rank
|
||||
) % src_heads_per_rank
|
||||
num_heads_to_send = dst_heads_per_rank
|
||||
dst_head_start_offset = 0
|
||||
|
||||
|
||||
@@ -78,6 +78,9 @@ class KVArgsRegisterInfo:
|
||||
dst_kv_ptrs: list[int]
|
||||
dst_aux_ptrs: list[int]
|
||||
gpu_id: int
|
||||
decode_tp_size: int
|
||||
decode_tp_rank: int
|
||||
dst_kv_item_len: int
|
||||
|
||||
@classmethod
|
||||
def from_zmq(cls, msg: List[bytes]):
|
||||
@@ -90,6 +93,9 @@ class KVArgsRegisterInfo:
|
||||
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
||||
dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
|
||||
gpu_id=int(msg[7].decode("ascii")),
|
||||
decode_tp_size=int(msg[8].decode("ascii")),
|
||||
decode_tp_rank=int(msg[9].decode("ascii")),
|
||||
dst_kv_item_len=int(msg[10].decode("ascii")),
|
||||
)
|
||||
|
||||
|
||||
@@ -166,7 +172,7 @@ class NixlKVManager(CommonKVManager):
|
||||
self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
|
||||
):
|
||||
kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
|
||||
self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
|
||||
self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM")
|
||||
logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
|
||||
if not self.kv_descs:
|
||||
raise Exception("NIXL memory registration failed for kv tensors")
|
||||
@@ -175,7 +181,7 @@ class NixlKVManager(CommonKVManager):
|
||||
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
||||
):
|
||||
aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
|
||||
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
|
||||
self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM")
|
||||
logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
|
||||
if not self.aux_descs:
|
||||
raise Exception("NIXL memory registration failed for aux tensors")
|
||||
@@ -222,8 +228,8 @@ class NixlKVManager(CommonKVManager):
|
||||
logger.debug(
|
||||
f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
|
||||
)
|
||||
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
|
||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
|
||||
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
|
||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
|
||||
# Transfer data
|
||||
xfer_handle = self.agent.initialize_xfer(
|
||||
"WRITE",
|
||||
@@ -239,6 +245,140 @@ class NixlKVManager(CommonKVManager):
|
||||
raise Exception("KVSender failed to post transfer")
|
||||
return xfer_handle
|
||||
|
||||
def send_kvcache_slice(
|
||||
self,
|
||||
peer_name: str,
|
||||
prefill_kv_indices: npt.NDArray[np.int32],
|
||||
dst_kv_ptrs: list[int],
|
||||
dst_kv_indices: npt.NDArray[np.int32],
|
||||
dst_gpu_id: int,
|
||||
notif: str,
|
||||
prefill_tp_size: int,
|
||||
decode_tp_size: int,
|
||||
decode_tp_rank: int,
|
||||
dst_kv_item_len: int,
|
||||
):
|
||||
# Get configuration from kv_args
|
||||
local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size
|
||||
dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
|
||||
num_kv_heads = self.kv_args.kv_head_num
|
||||
|
||||
# Calculate head distribution
|
||||
src_heads_per_rank = num_kv_heads
|
||||
dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size
|
||||
|
||||
src_kv_item_len = self.kv_args.kv_item_lens[0]
|
||||
page_size = self.kv_args.page_size
|
||||
|
||||
bytes_per_head_slice_to_send = (
|
||||
dst_kv_item_len // page_size // dst_heads_per_rank
|
||||
)
|
||||
|
||||
# Determine which heads to send
|
||||
if prefill_tp_size > decode_tp_size:
|
||||
# Multiple prefill ranks to one decode rank
|
||||
src_head_start_offset = 0
|
||||
num_heads_to_send = src_heads_per_rank
|
||||
dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
|
||||
else:
|
||||
# Send KVCache from 1 prefill instance to multiple decode instances
|
||||
src_head_start_offset = (
|
||||
dst_tp_rank_in_group * dst_heads_per_rank
|
||||
) % src_heads_per_rank
|
||||
num_heads_to_send = dst_heads_per_rank
|
||||
dst_head_start_offset = 0
|
||||
|
||||
# Create transfer descriptors
|
||||
src_addrs = []
|
||||
dst_addrs = []
|
||||
|
||||
bytes_per_token_on_prefill = src_kv_item_len // page_size
|
||||
bytes_per_token_on_decode = dst_kv_item_len // page_size
|
||||
|
||||
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
||||
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
||||
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
||||
dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
|
||||
dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
|
||||
|
||||
# Calculate precise byte offset and length for the sub-slice within the token
|
||||
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
||||
dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
|
||||
heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
|
||||
|
||||
src_dst_ptr_pairs = [
|
||||
(
|
||||
src_k_ptrs[layer_id],
|
||||
dst_k_ptrs[layer_id],
|
||||
)
|
||||
for layer_id in range(len(src_k_ptrs))
|
||||
] + [
|
||||
(
|
||||
src_v_ptrs[layer_id],
|
||||
dst_v_ptrs[layer_id],
|
||||
)
|
||||
for layer_id in range(len(src_v_ptrs))
|
||||
]
|
||||
|
||||
src_addrs = []
|
||||
dst_addrs = []
|
||||
|
||||
# Calculate strides for a single token slot
|
||||
bytes_per_token_on_prefill = src_kv_item_len // page_size
|
||||
bytes_per_token_on_decode = dst_kv_item_len // page_size
|
||||
|
||||
for src_ptr, dst_ptr in src_dst_ptr_pairs:
|
||||
for i in range(len(prefill_kv_indices)):
|
||||
prefill_page_idx = int(prefill_kv_indices[i])
|
||||
decode_page_idx = int(dst_kv_indices[i])
|
||||
|
||||
# Get the starting addresses for the current src and dst pages
|
||||
src_page_start_addr = src_ptr + prefill_page_idx * src_kv_item_len
|
||||
dst_page_start_addr = dst_ptr + decode_page_idx * dst_kv_item_len
|
||||
|
||||
# Iterate through each valid token slot within the current page
|
||||
for token_slot_in_page in range(page_size):
|
||||
# Calculate the start address of the current token slot
|
||||
src_token_slot_start_addr = (
|
||||
src_page_start_addr
|
||||
+ token_slot_in_page * bytes_per_token_on_prefill
|
||||
)
|
||||
dst_token_slot_start_addr = (
|
||||
dst_page_start_addr
|
||||
+ token_slot_in_page * bytes_per_token_on_decode
|
||||
)
|
||||
|
||||
# Calculate final src and dst addresses by applying head-slice offsets
|
||||
src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
|
||||
dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
|
||||
|
||||
src_addrs.append(
|
||||
(
|
||||
src_slice_addr,
|
||||
heads_bytes_per_token_to_send,
|
||||
self.kv_args.gpu_id,
|
||||
)
|
||||
)
|
||||
dst_addrs.append(
|
||||
(dst_slice_addr, heads_bytes_per_token_to_send, dst_gpu_id)
|
||||
)
|
||||
|
||||
# Use NIXL agent for transfer
|
||||
src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
|
||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
|
||||
|
||||
xfer_handle = self.agent.initialize_xfer(
|
||||
"WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii")
|
||||
)
|
||||
if not xfer_handle:
|
||||
raise Exception("Failed to create sliced KV transfer")
|
||||
|
||||
state = self.agent.transfer(xfer_handle)
|
||||
if state == "ERR":
|
||||
raise Exception("Failed to post sliced KV transfer")
|
||||
|
||||
return xfer_handle
|
||||
|
||||
def send_aux(
|
||||
self,
|
||||
peer_name: str,
|
||||
@@ -255,8 +395,8 @@ class NixlKVManager(CommonKVManager):
|
||||
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
|
||||
src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
|
||||
dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
|
||||
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
|
||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
|
||||
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
|
||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
|
||||
# Transfer data
|
||||
xfer_handle = self.agent.initialize_xfer(
|
||||
"WRITE",
|
||||
@@ -296,14 +436,35 @@ class NixlKVManager(CommonKVManager):
|
||||
assert req.agent_name in self.decode_kv_args_table
|
||||
|
||||
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
|
||||
kv_xfer_handle = self.send_kvcache(
|
||||
req.agent_name,
|
||||
kv_indices,
|
||||
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
|
||||
chunked_dst_kv_indice,
|
||||
self.decode_kv_args_table[req.agent_name].gpu_id,
|
||||
notif,
|
||||
)
|
||||
decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
|
||||
|
||||
if decode_tp_size == self.tp_size:
|
||||
kv_xfer_handle = self.send_kvcache(
|
||||
req.agent_name,
|
||||
kv_indices,
|
||||
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
|
||||
chunked_dst_kv_indice,
|
||||
self.decode_kv_args_table[req.agent_name].gpu_id,
|
||||
notif,
|
||||
)
|
||||
else:
|
||||
kv_xfer_handle = self.send_kvcache_slice(
|
||||
req.agent_name,
|
||||
kv_indices,
|
||||
self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
|
||||
chunked_dst_kv_indice,
|
||||
self.decode_kv_args_table[req.agent_name].gpu_id,
|
||||
notif,
|
||||
prefill_tp_size=self.tp_size,
|
||||
decode_tp_size=decode_tp_size,
|
||||
decode_tp_rank=self.decode_kv_args_table[
|
||||
req.agent_name
|
||||
].decode_tp_rank,
|
||||
dst_kv_item_len=self.decode_kv_args_table[
|
||||
req.agent_name
|
||||
].dst_kv_item_len,
|
||||
)
|
||||
|
||||
handles.append(kv_xfer_handle)
|
||||
# Only the last chunk we need to send the aux data.
|
||||
if is_last:
|
||||
@@ -521,6 +682,9 @@ class NixlKVReceiver(CommonKVReceiver):
|
||||
packed_kv_data_ptrs,
|
||||
packed_aux_data_ptrs,
|
||||
str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
|
||||
str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
|
||||
str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
|
||||
str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user