[P/D] Performance enhancement of Layerwise connector in TP asymmetric scenarios (#5540)
### What this PR does / why we need it?
[P/D] Performance enhancement of Layerwise connector in TP asymmetric
scenarios
1. Session fusion: at each layer, aggregate the transmission tasks that
share the same destination and merge them into a single task for
dispatch.
2. All-to-all aggregation: in TP-asymmetric scenarios, batch the
all-to-all operations for all requests and execute them in a single pass
at block granularity.
[RFC]: CDCP Scheduling for Disaggregated Prefilling with KV Cache
Layerwise Push Support
https://github.com/vllm-project/vllm-ascend/issues/4842
### Does this PR introduce _any_ user-facing change?
### How was this patch tested?
- vLLM version: v0.13.0
- vLLM main:
45c1ca1ca1
---------
Signed-off-by: liziyu <liziyu16@huawei.com>
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: nwpu-zxr <zhouxuerong2@huawei.com>
Co-authored-by: wangxiaoteng <wangxiaoteng@huawei.com>
This commit is contained in:
@@ -11,7 +11,7 @@ import time
|
||||
from collections import OrderedDict, defaultdict, deque
|
||||
from collections.abc import Iterator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, Optional, Tuple
|
||||
|
||||
import httpx
|
||||
@@ -19,7 +19,6 @@ import msgspec
|
||||
import numpy as np
|
||||
import numpy.typing as npt
|
||||
import torch
|
||||
import torch_npu
|
||||
import zmq
|
||||
from mooncake.engine import TransferEngine # type: ignore
|
||||
from vllm.config import VllmConfig
|
||||
@@ -37,6 +36,7 @@ from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
||||
from vllm_ascend.distributed.utils import (align_memory,
|
||||
get_transfer_timeout_value,
|
||||
kv_alltoall_and_rearrange)
|
||||
from vllm_ascend.utils import npu_stream_switch
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.attention.backends.abstract import AttentionMetadata
|
||||
@@ -56,7 +56,7 @@ class MooncakeAgentMetadata(msgspec.Struct, omit_defaults=True, dict=True):
|
||||
@dataclass
|
||||
class ReqMeta:
|
||||
local_block_ids: list[int]
|
||||
token_ids: list[int]
|
||||
token_ids: Optional[list[int]]
|
||||
# Not None if layer-wise is disabled
|
||||
remote_block_ids: list[int]
|
||||
remote_engine_id: Optional[str]
|
||||
@@ -68,6 +68,26 @@ class ReqMeta:
|
||||
chunk_finish: Optional[bool]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendTask:
|
||||
send_request: dict[str, ReqMeta] = field(default_factory=dict)
|
||||
# pd_head_ratio == 1 use
|
||||
wait_event: Optional[torch.npu.Event] = None
|
||||
# pd_head_ratio > 1 use
|
||||
k_cache: Optional[torch.Tensor] = None
|
||||
v_cache: Optional[torch.Tensor] = None
|
||||
layer_idx: int = 0
|
||||
rearrange_block_ids: Optional[list[int]] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TransferMeta:
|
||||
src: list[int]
|
||||
dst: list[int]
|
||||
length: list[int]
|
||||
req_ids: list[str]
|
||||
|
||||
|
||||
@dataclass
|
||||
class SendReqInfo:
|
||||
local_block_ids: list[int]
|
||||
@@ -116,19 +136,24 @@ class SizedDict(OrderedDict):
|
||||
|
||||
class KVCacheSendingLayerThread(threading.Thread):
|
||||
|
||||
def __init__(self,
|
||||
engine: TransferEngine,
|
||||
total_layers: int,
|
||||
ready_event: threading.Event,
|
||||
tp_rank: int,
|
||||
pd_head_ratio: int,
|
||||
num_head_replica: int,
|
||||
kv_cache_base_addr: list[int],
|
||||
use_mla: bool,
|
||||
block_len: list[int],
|
||||
decode_tp_size: int,
|
||||
first_kv_cache: torch.Tensor,
|
||||
callback_func: Callable[..., None] = lambda x: None):
|
||||
def __init__(
|
||||
self,
|
||||
engine: TransferEngine,
|
||||
total_layers: int,
|
||||
ready_event: threading.Event,
|
||||
tp_rank: int,
|
||||
pd_head_ratio: int,
|
||||
num_head_replica: int,
|
||||
kv_cache_base_addr: list[int],
|
||||
use_mla: bool,
|
||||
block_len: list[int],
|
||||
decode_tp_size: int,
|
||||
first_kv_cache: torch.Tensor,
|
||||
k_buffer: torch.Tensor,
|
||||
v_buffer: torch.Tensor,
|
||||
resharding_stream: torch.npu.Stream,
|
||||
callback_func: Callable[..., None] = lambda x: None,
|
||||
):
|
||||
super().__init__(daemon=True, name="KVCacheSendingLayerThread")
|
||||
self.engine = engine
|
||||
self.tp_rank = tp_rank
|
||||
@@ -139,39 +164,12 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
self.use_mla = use_mla
|
||||
self.block_len = block_len
|
||||
self._decode_tp_size = decode_tp_size
|
||||
self.model_stream = torch_npu.npu.current_stream()
|
||||
self.resharding_stream = resharding_stream
|
||||
self.current_layer = -1
|
||||
|
||||
if self.pd_head_ratio > 1:
|
||||
# regesit kv buffer for tp inequal
|
||||
alignment = 2 * 1024 * 1024
|
||||
self.k_buffer = torch.zeros(first_kv_cache.numel() + alignment,
|
||||
dtype=first_kv_cache.dtype,
|
||||
device=first_kv_cache.device)
|
||||
self.k_buffer = align_memory(
|
||||
self.k_buffer, alignment)[:first_kv_cache.numel()].view(
|
||||
-1, first_kv_cache.shape[-1])
|
||||
self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment,
|
||||
dtype=first_kv_cache.dtype,
|
||||
device=first_kv_cache.device)
|
||||
self.v_buffer = align_memory(
|
||||
self.v_buffer, alignment)[:first_kv_cache.numel()].view(
|
||||
-1, first_kv_cache.shape[-1])
|
||||
|
||||
for tensor in (self.k_buffer, self.v_buffer):
|
||||
assert tensor.data_ptr(
|
||||
) % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
|
||||
ret_value = self.engine.register_memory(
|
||||
tensor.data_ptr(), tensor.numel())
|
||||
logger.info(
|
||||
f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
|
||||
)
|
||||
if ret_value != 0:
|
||||
raise RuntimeError("Mooncake memory registration failed. ")
|
||||
|
||||
self.send_queue = queue.Queue[Tuple[str, ReqMeta, int, torch.Tensor,
|
||||
torch.Tensor, torch.npu.Event]]()
|
||||
|
||||
self.send_queue = queue.Queue[SendTask]()
|
||||
self.k_buffer = k_buffer
|
||||
self.v_buffer = v_buffer
|
||||
self.ready_event = ready_event
|
||||
self.callback_func = callback_func
|
||||
|
||||
@@ -181,43 +179,36 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
torch.npu.set_device(device)
|
||||
self.ready_event.set()
|
||||
while True:
|
||||
req_id, req_meta, layer_index, key, value, reshape_cache_event = self.send_queue.get(
|
||||
)
|
||||
self._handle_request(req_id, req_meta, layer_index, key, value,
|
||||
reshape_cache_event)
|
||||
send_task = self.send_queue.get()
|
||||
self._handle_request(send_task)
|
||||
|
||||
def _handle_request(self, req_id, req_meta, layer_index, key, value,
|
||||
reshape_cache_event):
|
||||
def _handle_request(self, send_task: SendTask):
|
||||
try:
|
||||
logger.debug(
|
||||
f"Starting to transfer KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
|
||||
)
|
||||
self._transfer_kv_cache(req_id, req_meta, layer_index, key, value,
|
||||
reshape_cache_event)
|
||||
logger.debug(
|
||||
f"Finished transferring KV cache for request {req_id} {req_meta.remote_te_rpc_port=}."
|
||||
)
|
||||
self._transfer_kv_cache(send_task)
|
||||
except Exception as e:
|
||||
logger.error("Failed to transfer KV cache for request "
|
||||
f"{req_id}: {e}")
|
||||
logger.error(
|
||||
f"Failed to transfer KV cache for layer idx {send_task.layer_idx}, {e}"
|
||||
)
|
||||
|
||||
def _transfer_kv_cache(self, req_id, req_meta, layer_index, key, value,
|
||||
reshape_cache_event):
|
||||
def get_transfer_meta(self, send_task: SendTask, req_id: str,
|
||||
req_meta: ReqMeta):
|
||||
src_list: list[str] = []
|
||||
dst_list: list[str] = []
|
||||
length_list: list[int] = []
|
||||
# not need to send kv cache
|
||||
if self.tp_rank % self.num_head_replica != 0:
|
||||
logger.debug(
|
||||
f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})."
|
||||
)
|
||||
return
|
||||
return (src_list, dst_list, length_list)
|
||||
if self.use_mla and self.tp_rank >= self._decode_tp_size:
|
||||
logger.debug(
|
||||
f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})."
|
||||
)
|
||||
return
|
||||
return (src_list, dst_list, length_list)
|
||||
|
||||
remote_host = req_meta.remote_host
|
||||
layer_idx = send_task.layer_idx
|
||||
remote_block_ids = req_meta.remote_block_ids
|
||||
remote_te_port = req_meta.remote_te_rpc_port
|
||||
remote_kv_base_addrs = req_meta.remote_kv_caches_base_addr
|
||||
local_kv_base_addr = self.kv_caches_base_addr
|
||||
local_block_ids = req_meta.local_block_ids
|
||||
@@ -225,17 +216,15 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
if self.pd_head_ratio == 1:
|
||||
layer_local_kv_base_addr = [
|
||||
local_kv_base_addr[i]
|
||||
for i in [2 * layer_index, 2 * layer_index + 1]
|
||||
for i in [2 * layer_idx, 2 * layer_idx + 1]
|
||||
]
|
||||
layer_remote_kv_base_addr = [
|
||||
remote_kv_base_addrs[i]
|
||||
for i in [2 * layer_index, 2 * layer_index + 1]
|
||||
remote_kv_base_addrs[i] # type:ignore
|
||||
for i in [2 * layer_idx, 2 * layer_idx + 1]
|
||||
]
|
||||
grouped_remote_block_ids, grouped_local_block_ids = \
|
||||
group_concurrent_contiguous(remote_block_ids, local_block_ids)
|
||||
|
||||
session_id = f"{remote_host}:{remote_te_port}"
|
||||
src_list, dst_list, length_list = [], [], []
|
||||
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
||||
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
|
||||
block_len = self.block_len[
|
||||
@@ -250,74 +239,101 @@ class KVCacheSendingLayerThread(threading.Thread):
|
||||
src_list.append(src)
|
||||
dst_list.append(dst)
|
||||
length_list.append(length)
|
||||
if self.current_layer != layer_index:
|
||||
self.current_layer = layer_index
|
||||
"""
|
||||
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
|
||||
This issue will be fixed in CANN version 8.5.rc1.
|
||||
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
|
||||
to resolve this issue before the 8.5.RC1 release.
|
||||
"""
|
||||
reshape_cache_event.synchronize()
|
||||
ret = self.engine.batch_transfer_sync_write(
|
||||
session_id, src_list, dst_list, length_list)
|
||||
if ret < 0:
|
||||
logger.error("Mooncake transfer failed for request %s", req_id)
|
||||
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
|
||||
else:
|
||||
key = key.view(-1, key.shape[-1])
|
||||
value = value.view(-1, key.shape[-1])
|
||||
self.k_buffer[:key.shape[0]].copy_(key) # [:4, 128] ->
|
||||
self.v_buffer[:value.shape[0]].copy_(value)
|
||||
|
||||
rearrange_block_ids = send_task.rearrange_block_ids
|
||||
rearrange_block_dict = {
|
||||
value: index
|
||||
for index, value in enumerate(
|
||||
rearrange_block_ids) # type:ignore
|
||||
}
|
||||
layer_local_kv_base_addr = [
|
||||
self.k_buffer.data_ptr(),
|
||||
self.v_buffer.data_ptr()
|
||||
]
|
||||
|
||||
layer_remote_kv_base_addr = [
|
||||
remote_kv_base_addrs[i]
|
||||
for i in [2 * layer_index, 2 * layer_index + 1]
|
||||
remote_kv_base_addrs[i] # type:ignore
|
||||
for i in [2 * layer_idx, 2 * layer_idx + 1]
|
||||
]
|
||||
|
||||
grouped_remote_block_ids, _ = group_concurrent_contiguous(
|
||||
remote_block_ids)
|
||||
|
||||
session_id = f"{remote_host}:{remote_te_port}"
|
||||
src_list, dst_list, length_list = [], [], []
|
||||
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
||||
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
|
||||
src_layer_addr = src_layer_base_addr
|
||||
for group_remote_block_id in grouped_remote_block_ids:
|
||||
block_len = self.block_len[0]
|
||||
remote_block_len = self.block_len[0] * self.pd_head_ratio
|
||||
src_list.append(src_layer_addr)
|
||||
block_len = self.block_len[0]
|
||||
remote_block_len = self.block_len[0] * self.pd_head_ratio
|
||||
for remote_block_id, local_block_id in zip(
|
||||
remote_block_ids, local_block_ids):
|
||||
src = src_layer_base_addr + rearrange_block_dict[
|
||||
local_block_id] * block_len
|
||||
dst = dst_layer_base_addr + remote_block_id * remote_block_len + block_len * (
|
||||
(self.tp_rank // self.num_head_replica) %
|
||||
self.pd_head_ratio)
|
||||
src_list.append(src)
|
||||
dst_list.append(dst)
|
||||
length_list.append(block_len)
|
||||
return (src_list, dst_list, length_list)
|
||||
|
||||
if src_layer_addr + len(
|
||||
group_remote_block_id
|
||||
) * block_len > src_layer_base_addr + key.numel(
|
||||
) * key.element_size():
|
||||
length = src_layer_base_addr + key.numel(
|
||||
) * key.element_size() - src_layer_addr
|
||||
else:
|
||||
length = len(group_remote_block_id) * block_len
|
||||
length_list.append(length)
|
||||
def _transfer_kv_cache(self, send_task: SendTask):
|
||||
if self.pd_head_ratio > 1:
|
||||
with npu_stream_switch(self.resharding_stream):
|
||||
key = send_task.k_cache
|
||||
value = send_task.v_cache
|
||||
key = key.view(-1, key.shape[-1]) # type:ignore
|
||||
value = value.view(-1, key.shape[-1]) # type:ignore
|
||||
self.k_buffer[:key.shape[0]].copy_(key) # [:4, 128] ->
|
||||
self.v_buffer[:value.shape[0]].copy_(value)
|
||||
|
||||
dst_list.append(dst_layer_base_addr +
|
||||
group_remote_block_id[0] *
|
||||
remote_block_len + length *
|
||||
((self.tp_rank // self.num_head_replica) %
|
||||
self.pd_head_ratio))
|
||||
src_layer_addr += length
|
||||
self.model_stream.synchronize()
|
||||
ret = self.engine.batch_transfer_sync_write(
|
||||
session_id, src_list, dst_list, length_list)
|
||||
if ret < 0:
|
||||
logger.error("Mooncake transfer failed for request %s", req_id)
|
||||
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
|
||||
# Merge transmission tasks of the same session
|
||||
session_meta: dict[str, TransferMeta] = {}
|
||||
for req_id, req_meta in send_task.send_request.items():
|
||||
session_id = f"{req_meta.remote_host}:{req_meta.remote_te_rpc_port}"
|
||||
if session_id not in session_meta.keys():
|
||||
session_meta[session_id] = TransferMeta(src=[],
|
||||
dst=[],
|
||||
length=[],
|
||||
req_ids=[])
|
||||
|
||||
if layer_index == (self.total_layers - 1) and req_meta.chunk_finish:
|
||||
self.callback_func(req_id, req_meta)
|
||||
(src_list, dst_list,
|
||||
length_list) = self.get_transfer_meta(send_task, req_id, req_meta)
|
||||
|
||||
session_meta[session_id].src.extend(src_list)
|
||||
session_meta[session_id].dst.extend(dst_list)
|
||||
session_meta[session_id].length.extend(length_list)
|
||||
session_meta[session_id].req_ids.append(req_id)
|
||||
|
||||
if self.pd_head_ratio == 1:
|
||||
"""
|
||||
Note: Due to a bug in ADXL, calling current_event.synchronize() may occasionally hang.
|
||||
This issue will be fixed in CANN version 8.5.rc1.
|
||||
You can manually build the master branch of the project at https://gitcode.com/cann/hixl
|
||||
to resolve this issue before the 8.5.RC1 release.
|
||||
"""
|
||||
send_task.wait_event.synchronize() # type:ignore
|
||||
elif self.pd_head_ratio > 1:
|
||||
self.resharding_stream.synchronize()
|
||||
|
||||
for session_id, transfer_meta in session_meta.items():
|
||||
if len(transfer_meta.src) > 0:
|
||||
ret = self.engine.batch_transfer_sync_write(
|
||||
session_id, transfer_meta.src, transfer_meta.dst,
|
||||
transfer_meta.length)
|
||||
if ret < 0:
|
||||
logger.error(
|
||||
f"Mooncake transfer failed for send requests {transfer_meta.req_ids} kv cache to {session_id}"
|
||||
)
|
||||
if send_task.layer_idx == (self.total_layers - 1):
|
||||
for req_id in transfer_meta.req_ids:
|
||||
req_meta = send_task.send_request[req_id]
|
||||
if req_meta.chunk_finish:
|
||||
self.callback_func(
|
||||
req_id, req_meta
|
||||
) # TODO Send a signal indicating transmission failure
|
||||
else:
|
||||
if send_task.layer_idx == (self.total_layers - 1):
|
||||
for req_id in transfer_meta.req_ids:
|
||||
req_meta = send_task.send_request[req_id]
|
||||
if req_meta.chunk_finish:
|
||||
self.callback_func(req_id, req_meta)
|
||||
|
||||
|
||||
class KVCacheRecvingLayerThread(threading.Thread):
|
||||
@@ -836,8 +852,10 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
self.pd_tp_ratio = get_ascend_config().pd_tp_ratio
|
||||
self.pd_head_ratio = get_ascend_config().pd_head_ratio
|
||||
self.num_head_replica = get_ascend_config().num_head_replica
|
||||
self.resharding_stream = None
|
||||
if self.pd_head_ratio > 1:
|
||||
self.resharding_stream = torch.npu.Stream()
|
||||
|
||||
self.first_kv_cache = None
|
||||
self.remote_poller = zmq.Poller() # type: ignore
|
||||
self.decoder = msgspec.msgpack.Decoder(MooncakeAgentMetadata)
|
||||
self.encoder = msgspec.msgpack.Encoder()
|
||||
@@ -852,6 +870,8 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
deque)
|
||||
self.remote_poller = zmq.Poller() # type: ignore
|
||||
self.timeout = 1.0 # seconds
|
||||
self.k_buffer: Optional[torch.Tensor] = None
|
||||
self.v_buffer: Optional[torch.Tensor] = None
|
||||
|
||||
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
|
||||
# get prefill tp and dp size from extra config
|
||||
@@ -874,12 +894,40 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
assert "dp_size" in decode_parallel_config.keys()
|
||||
self._decode_dp_size = decode_parallel_config["dp_size"]
|
||||
|
||||
def create_kv_buffer(self, first_kv_cache):
|
||||
if self.pd_head_ratio > 1:
|
||||
# regesit kv buffer for tp inequal
|
||||
alignment = 2 * 1024 * 1024
|
||||
self.k_buffer = torch.zeros(first_kv_cache.numel() + alignment,
|
||||
dtype=first_kv_cache.dtype,
|
||||
device=first_kv_cache.device)
|
||||
self.k_buffer = align_memory(
|
||||
self.k_buffer, alignment)[:first_kv_cache.numel()].view(
|
||||
-1, first_kv_cache.shape[-1])
|
||||
self.v_buffer = torch.zeros(first_kv_cache.numel() + alignment,
|
||||
dtype=first_kv_cache.dtype,
|
||||
device=first_kv_cache.device)
|
||||
self.v_buffer = align_memory(
|
||||
self.v_buffer, alignment)[:first_kv_cache.numel()].view(
|
||||
-1, first_kv_cache.shape[-1])
|
||||
|
||||
for tensor in (self.k_buffer, self.v_buffer):
|
||||
assert tensor.data_ptr(
|
||||
) % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
|
||||
ret_value = self.engine.register_memory(
|
||||
tensor.data_ptr(), tensor.numel())
|
||||
logger.info(
|
||||
f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
|
||||
)
|
||||
if ret_value != 0:
|
||||
raise RuntimeError("Mooncake memory registration failed. ")
|
||||
|
||||
def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
|
||||
"""Register the KV Cache data."""
|
||||
|
||||
_, first_kv_cache_tuple = next(iter(kv_caches.items()))
|
||||
first_kv_cache = first_kv_cache_tuple[0]
|
||||
self.first_kv_cache = first_kv_cache
|
||||
self.create_kv_buffer(first_kv_cache)
|
||||
|
||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||
self.use_mla = first_kv_cache_tuple[0].size(
|
||||
@@ -954,6 +1002,9 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
block_len=self.block_len,
|
||||
decode_tp_size=self._decode_tp_size,
|
||||
first_kv_cache=first_kv_cache,
|
||||
k_buffer=self.k_buffer,
|
||||
v_buffer=self.v_buffer,
|
||||
resharding_stream=self.resharding_stream,
|
||||
callback_func=self.send_done_send_signal)
|
||||
self.kv_send_layer_thread.start()
|
||||
ready_event.wait()
|
||||
@@ -1002,60 +1053,49 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
reshape_cache_event = attn_metadata.reshape_cache_event
|
||||
|
||||
if self.pd_head_ratio != 1:
|
||||
assert self.resharding_stream is not None
|
||||
with npu_stream_switch(self.resharding_stream):
|
||||
reshape_cache_event.wait()
|
||||
rearrange_block_ids = sorted({
|
||||
block_id
|
||||
for request in connector_metadata.requests.values()
|
||||
for block_id in request.local_block_ids
|
||||
})
|
||||
|
||||
def sort_kv_cache(input_kv: list[list[int]]):
|
||||
return torch.cat([
|
||||
torch.chunk(tensor, self.pd_head_ratio, dim=0)[x]
|
||||
for x in range(self.pd_head_ratio)
|
||||
for tensor in input_kv
|
||||
])
|
||||
|
||||
total_block_ids = [
|
||||
request.local_block_ids
|
||||
for request in connector_metadata.requests.values()
|
||||
]
|
||||
keys = [
|
||||
kv_layer[0][block_ids].reshape(
|
||||
-1, *kv_layer[0].shape[2:]).clone()
|
||||
for block_ids in total_block_ids
|
||||
]
|
||||
values = [
|
||||
kv_layer[1][block_ids].reshape(
|
||||
-1, *kv_layer[1].shape[2:]).clone()
|
||||
for block_ids in total_block_ids
|
||||
]
|
||||
key_block_size = keys[0].size(0) // len(total_block_ids[0])
|
||||
value_block_size = values[0].size(0) // len(total_block_ids[0])
|
||||
keys = sort_kv_cache(keys) # [req1_key, req2_key]
|
||||
values = sort_kv_cache(values)
|
||||
(keys,
|
||||
values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys,
|
||||
values)
|
||||
key_start_id = 0
|
||||
value_start_id = 0
|
||||
keys = kv_layer[0][rearrange_block_ids].clone()
|
||||
values = kv_layer[1][rearrange_block_ids].clone()
|
||||
# sort kv caches for each block
|
||||
keys = keys.view(keys.size(0), self.pd_head_ratio, -1,
|
||||
*keys.shape[2:]).transpose(
|
||||
0, 1).reshape_as(keys)
|
||||
values = values.view(values.size(0), self.pd_head_ratio,
|
||||
-1, *values.shape[2:]).transpose(
|
||||
0, 1).reshape_as(values)
|
||||
# reshard kv cache
|
||||
keys = keys.reshape(-1, *kv_layer[0].shape[2:])
|
||||
values = values.reshape(-1, *kv_layer[1].shape[2:])
|
||||
(keys, values) = kv_alltoall_and_rearrange(
|
||||
self.pd_head_ratio, keys, values)
|
||||
else:
|
||||
key = None
|
||||
value = None
|
||||
keys = None
|
||||
values = None
|
||||
rearrange_block_ids = None
|
||||
|
||||
assert self.kv_send_layer_thread is not None
|
||||
assert reshape_cache_event is not None
|
||||
send_task = SendTask(wait_event=reshape_cache_event,
|
||||
k_cache=keys,
|
||||
v_cache=values,
|
||||
layer_idx=self.current_layer,
|
||||
rearrange_block_ids=rearrange_block_ids)
|
||||
for req_id, req_meta in connector_metadata.requests.items():
|
||||
if self.pd_head_ratio != 1:
|
||||
key_block_num = len(
|
||||
req_meta.local_block_ids) * key_block_size
|
||||
value_block_num = len(
|
||||
req_meta.local_block_ids) * value_block_size
|
||||
key = keys[key_start_id:key_start_id + key_block_num]
|
||||
value = values[value_start_id:value_start_id +
|
||||
value_block_num]
|
||||
key_start_id += key_block_num
|
||||
value_start_id += value_block_num
|
||||
req_meta_update = self.update_decoder_info(req_id, req_meta)
|
||||
logger.debug(
|
||||
f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
|
||||
)
|
||||
assert self.kv_send_layer_thread is not None
|
||||
assert reshape_cache_event is not None
|
||||
self.kv_send_layer_thread.send_queue.put(
|
||||
(req_id, req_meta_update, self.current_layer, key, value,
|
||||
reshape_cache_event))
|
||||
send_task.send_request[req_id] = req_meta_update
|
||||
|
||||
self.kv_send_layer_thread.send_queue.put(send_task)
|
||||
self.current_layer += 1
|
||||
|
||||
def _get_remote_socket(
|
||||
@@ -1106,6 +1146,14 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
logger.info(
|
||||
f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
|
||||
)
|
||||
session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
|
||||
ret = self.engine.batch_transfer_sync_write(
|
||||
session_id, [self.kv_caches_base_addr[0]],
|
||||
[agent_meta.kv_caches_base_addr[0]], 128)
|
||||
if ret < 0:
|
||||
logger.error(
|
||||
f"Mooncake transfer failed to create link to device {session_id}"
|
||||
)
|
||||
req_meta_update.remote_te_rpc_port = self.remote_te_port[
|
||||
req_meta_update.remote_engine_id][req_meta_update.remote_port]
|
||||
req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[
|
||||
|
||||
Reference in New Issue
Block a user