add code pp support for nixl (#11375)
Signed-off-by: Shahar Mor <smor@nvidia.com>
This commit is contained in:
@@ -319,14 +319,44 @@ class NixlKVManager(CommonKVManager):
|
||||
|
||||
logger.debug(f"sending kvcache to {peer_name} with notif {notif}")
|
||||
# Make descs
|
||||
num_layers = len(self.kv_args.kv_data_ptrs)
|
||||
if self.is_mla_backend:
|
||||
src_kv_ptrs, dst_kv_ptrs, layers_current_pp_stage = (
|
||||
self.get_mla_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||
)
|
||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||
layers_params = [
|
||||
(
|
||||
src_kv_ptrs[layer_id],
|
||||
dst_kv_ptrs[layer_id],
|
||||
kv_item_len,
|
||||
)
|
||||
for layer_id in range(layers_current_pp_stage)
|
||||
]
|
||||
else:
|
||||
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
||||
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||
)
|
||||
|
||||
kv_item_len = self.kv_args.kv_item_lens[0]
|
||||
layers_params = [
|
||||
(
|
||||
src_k_ptrs[layer_id],
|
||||
dst_k_ptrs[layer_id],
|
||||
kv_item_len,
|
||||
)
|
||||
for layer_id in range(layers_current_pp_stage)
|
||||
] + [
|
||||
(
|
||||
src_v_ptrs[layer_id],
|
||||
dst_v_ptrs[layer_id],
|
||||
kv_item_len,
|
||||
)
|
||||
for layer_id in range(layers_current_pp_stage)
|
||||
]
|
||||
|
||||
src_addrs = []
|
||||
dst_addrs = []
|
||||
for layer_id in range(num_layers):
|
||||
src_ptr = self.kv_args.kv_data_ptrs[layer_id]
|
||||
dst_ptr = dst_kv_ptrs[layer_id]
|
||||
item_len = self.kv_args.kv_item_lens[layer_id]
|
||||
|
||||
for src_ptr, dst_ptr, item_len in layers_params:
|
||||
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
||||
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
||||
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
||||
@@ -397,6 +427,9 @@ class NixlKVManager(CommonKVManager):
|
||||
num_heads_to_send = dst_heads_per_rank
|
||||
dst_head_start_offset = 0
|
||||
|
||||
src_k_ptrs, src_v_ptrs, dst_k_ptrs, dst_v_ptrs, layers_current_pp_stage = (
|
||||
self.get_mha_kv_ptrs_with_pp(self.kv_args.kv_data_ptrs, dst_kv_ptrs)
|
||||
)
|
||||
# Create transfer descriptors
|
||||
src_addrs = []
|
||||
dst_addrs = []
|
||||
@@ -404,12 +437,6 @@ class NixlKVManager(CommonKVManager):
|
||||
bytes_per_token_on_prefill = src_kv_item_len // page_size
|
||||
bytes_per_token_on_decode = dst_kv_item_len // page_size
|
||||
|
||||
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
||||
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
||||
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
||||
dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
|
||||
dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
|
||||
|
||||
# Calculate precise byte offset and length for the sub-slice within the token
|
||||
src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
|
||||
dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
|
||||
@@ -420,13 +447,13 @@ class NixlKVManager(CommonKVManager):
|
||||
src_k_ptrs[layer_id],
|
||||
dst_k_ptrs[layer_id],
|
||||
)
|
||||
for layer_id in range(len(src_k_ptrs))
|
||||
for layer_id in range(layers_current_pp_stage)
|
||||
] + [
|
||||
(
|
||||
src_v_ptrs[layer_id],
|
||||
dst_v_ptrs[layer_id],
|
||||
)
|
||||
for layer_id in range(len(src_v_ptrs))
|
||||
for layer_id in range(layers_current_pp_stage)
|
||||
]
|
||||
|
||||
src_addrs = []
|
||||
@@ -496,14 +523,19 @@ class NixlKVManager(CommonKVManager):
|
||||
dst_aux_index: int,
|
||||
notif: str,
|
||||
):
|
||||
# Make descs
|
||||
aux_item_len = self.kv_args.aux_item_lens[0]
|
||||
prefill_aux_addr = (
|
||||
self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
|
||||
)
|
||||
decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
|
||||
src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
|
||||
dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
|
||||
src_addrs = []
|
||||
dst_addrs = []
|
||||
|
||||
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
||||
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
||||
|
||||
for i, _ in enumerate(dst_aux_ptrs):
|
||||
length = prefill_aux_item_lens[i]
|
||||
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
||||
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
|
||||
src_addrs.append((src_addr, length, 0))
|
||||
dst_addrs.append((dst_addr, length, 0))
|
||||
|
||||
src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
|
||||
dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
|
||||
# Transfer data
|
||||
@@ -576,7 +608,7 @@ class NixlKVManager(CommonKVManager):
|
||||
|
||||
handles.append(kv_xfer_handle)
|
||||
# Only the last chunk we need to send the aux data.
|
||||
if is_last:
|
||||
if is_last and self.pp_group.is_last_rank:
|
||||
assert aux_index is not None
|
||||
aux_xfer_handle = self.send_aux(
|
||||
req.agent_name,
|
||||
|
||||
Reference in New Issue
Block a user