add code pp support for nixl (#11375)
Signed-off-by: Shahar Mor <smor@nvidia.com>
This commit is contained in:
@@ -319,14 +319,44 @@ class NixlKVManager(CommonKVManager):
|
|||||||
|
|
||||||
logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
|
logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
|
||||||
# Make descs
|
# Make descs
|
||||||
num_layers = len(self.kv_args.kv_data_ptrs)
|
if self.is_mla_backend:
|
||||||
|
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
|
||||||
|
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||||
|
)
|
||||||
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||||
|
layers_params = [
|
||||||
|
(
|
||||||
|
src_kv_ptrs[layer_id],
|
||||||
|
dst_kv_ptrs[layer_id],
|
||||||
|
kv_item_len,
|
||||||
|
)
|
||||||
|
for layer_id in range(layers_current_pp_stage)
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
||||||
|
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||||
|
)
|
||||||
|
|
||||||
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||||
|
layers_params = [
|
||||||
|
(
|
||||||
|
src_k_ptrs[layer_id],
|
||||||
|
dst_k_ptrs[layer_id],
|
||||||
|
kv_item_len,
|
||||||
|
)
|
||||||
|
for layer_id in range(layers_current_pp_stage)
|
||||||
|
] + [
|
||||||
|
(
|
||||||
|
src_v_ptrs[layer_id],
|
||||||
|
dst_v_ptrs[layer_id],
|
||||||
|
kv_item_len,
|
||||||
|
)
|
||||||
|
for layer_id in range(layers_current_pp_stage)
|
||||||
|
]
|
||||||
|
|
||||||
src_addrs = []
|
src_addrs = []
|
||||||
dst_addrs = []
|
dst_addrs = []
|
||||||
for layer_id in range(num_layers):
|
for src_ptr, dst_ptr, item_len in layers_params:
|
||||||
src_ptr = self.kv_args.kv_data_ptrs[layer_id]
|
|
||||||
dst_ptr = dst_kv_ptrs[layer_id]
|
|
||||||
item_len = self.kv_args.kv_item_lens[layer_id]
|
|
||||||
|
|
||||||
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
||||||
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
||||||
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
||||||
@@ -397,6 +427,9 @@ class NixlKVManager(CommonKVManager):
|
|||||||
num_heads_to_send = dst_heads_per_rank
|
num_heads_to_send = dst_heads_per_rank
|
||||||
dst_head_start_offset = 0
|
dst_head_start_offset = 0
|
||||||
|
|
||||||
|
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
||||||
|
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||||
|
)
|
||||||
# Create transfer descriptors
|
# Create transfer descriptors
|
||||||
src_addrs = []
|
src_addrs = []
|
||||||
dst_addrs = []
|
dst_addrs = []
|
||||||
@@ -404,12 +437,6 @@ class NixlKVManager(CommonKVManager):
|
|||||||
bytes_per_token_on_prefill = src_kv_item_len // page_size
|
bytes_per_token_on_prefill = src_kv_item_len // page_size
|
||||||
bytes_per_token_on_decode = dst_kv_item_len // page_size
|
bytes_per_token_on_decode = dst_kv_item_len // page_size
|
||||||
|
|
||||||
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
|
||||||
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
|
||||||
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
|
||||||
dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
|
|
||||||
dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
|
|
||||||
|
|
||||||
# Calculate precise byte offset and length for the sub-slice within the token
|
# Calculate precise byte offset and length for the sub-slice within the token
|
||||||
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
||||||
dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
|
dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
|
||||||
@@ -420,13 +447,13 @@ class NixlKVManager(CommonKVManager):
|
|||||||
src_k_ptrs[layer_id],
|
src_k_ptrs[layer_id],
|
||||||
dst_k_ptrs[layer_id],
|
dst_k_ptrs[layer_id],
|
||||||
)
|
)
|
||||||
for layer_id in range(len(src_k_ptrs))
|
for layer_id in range(layers_current_pp_stage)
|
||||||
] + [
|
] + [
|
||||||
(
|
(
|
||||||
src_v_ptrs[layer_id],
|
src_v_ptrs[layer_id],
|
||||||
dst_v_ptrs[layer_id],
|
dst_v_ptrs[layer_id],
|
||||||
)
|
)
|
||||||
for layer_id in range(len(src_v_ptrs))
|
for layer_id in range(layers_current_pp_stage)
|
||||||
]
|
]
|
||||||
|
|
||||||
src_addrs = []
|
src_addrs = []
|
||||||
@@ -496,14 +523,19 @@ class NixlKVManager(CommonKVManager):
|
|||||||
dst_aux_index: int,
|
dst_aux_index: int,
|
||||||
notif: str,
|
notif: str,
|
||||||
):
|
):
|
||||||
# Make descs
|
src_addrs = []
|
||||||
aux_item_len = self.kv_args.aux_item_lens[0]
|
dst_addrs = []
|
||||||
prefill_aux_addr = (
|
|
||||||
self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
|
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
||||||
)
|
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
||||||
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
|
|
||||||
src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
|
for i, _ in enumerate(dst_aux_ptrs):
|
||||||
dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
|
length = prefill_aux_item_lens[i]
|
||||||
|
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
||||||
|
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
|
||||||
|
src_addrs.append((src_addr, length, 0))
|
||||||
|
dst_addrs.append((dst_addr, length, 0))
|
||||||
|
|
||||||
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
|
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
|
||||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
|
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
|
||||||
# Transfer data
|
# Transfer data
|
||||||
@@ -576,7 +608,7 @@ class NixlKVManager(CommonKVManager):
|
|||||||
|
|
||||||
handles.append(kv_xfer_handle)
|
handles.append(kv_xfer_handle)
|
||||||
# Only the last chunk we need to send the aux data.
|
# Only the last chunk we need to send the aux data.
|
||||||
if is_last:
|
if is_last and self.pp_group.is_last_rank:
|
||||||
assert aux_index is not None
|
assert aux_index is not None
|
||||||
aux_xfer_handle = self.send_aux(
|
aux_xfer_handle = self.send_aux(
|
||||||
req.agent_name,
|
req.agent_name,
|
||||||
|
|||||||
Reference in New Issue
Block a user