[P/D] layerwise connector supports DeepSeek-V3.2 sparse attention && Distribute transfer tasks to redundant kv_head cards (#5722)
### What this PR does / why we need it?
Add new functionality to the Mooncake layerwise connector, including:
1. Support for sparse attention, as used by DeepSeek-V3.2.
2. Distribution of transfer tasks across the cards that hold redundant kv_head replicas.
This PR is related to [[RFC]: CDCP Scheduling for Disaggregated
Prefilling with KV Cache Layerwise Push
Support](https://github.com/vllm-project/vllm-ascend/issues/4842)
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
By CI.
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
---------
Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: liziyu <liziyu16@huawei.com>
Co-authored-by: liziyu <liziyu16@huawei.com>
This commit is contained in:
@@ -170,36 +170,6 @@ class TestKVCacheSendingLayerThread(unittest.TestCase):
|
|||||||
self.thread._transfer_kv_cache(send_task)
|
self.thread._transfer_kv_cache(send_task)
|
||||||
self.engine.batch_transfer_sync_write.assert_not_called()
|
self.engine.batch_transfer_sync_write.assert_not_called()
|
||||||
|
|
||||||
def test_transfer_skips_when_tp_not_sender(self):
|
|
||||||
|
|
||||||
thread = KVCacheSendingLayerThread(
|
|
||||||
engine=self.engine,
|
|
||||||
total_layers=2,
|
|
||||||
ready_event=self.ready_event,
|
|
||||||
tp_rank=1,
|
|
||||||
pd_head_ratio=1,
|
|
||||||
num_head_replica=2,
|
|
||||||
kv_cache_base_addr=[1000, 2000, 3000, 4000],
|
|
||||||
use_mla=False,
|
|
||||||
block_len=[1024],
|
|
||||||
decode_tp_size=1,
|
|
||||||
first_kv_cache=self.first_kv_cache,
|
|
||||||
k_buffer=MagicMock(),
|
|
||||||
v_buffer=MagicMock(),
|
|
||||||
resharding_stream=MagicMock(),
|
|
||||||
callback_func=MagicMock())
|
|
||||||
req_meta = self.req_meta_base
|
|
||||||
send_task = SendTask(
|
|
||||||
send_request={"req3": req_meta},
|
|
||||||
wait_event=MagicMock(),
|
|
||||||
k_cache=self.key,
|
|
||||||
v_cache=self.value,
|
|
||||||
layer_idx=1,
|
|
||||||
rearrange_block_ids=[],
|
|
||||||
)
|
|
||||||
thread._transfer_kv_cache(send_task)
|
|
||||||
self.engine.batch_transfer_sync_write.assert_not_called()
|
|
||||||
|
|
||||||
@patch(
|
@patch(
|
||||||
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous",
|
"vllm_ascend.distributed.mooncake_layerwise_connector.group_concurrent_contiguous",
|
||||||
side_effect=group_concurrent_contiguous)
|
side_effect=group_concurrent_contiguous)
|
||||||
@@ -425,6 +395,7 @@ class MockVllmConfig:
|
|||||||
self.parallel_config.data_parallel_size = 1
|
self.parallel_config.data_parallel_size = 1
|
||||||
self.parallel_config.data_parallel_rank = 0
|
self.parallel_config.data_parallel_rank = 0
|
||||||
self.cache_config.block_size = 16
|
self.cache_config.block_size = 16
|
||||||
|
self.model_config.hf_config.num_key_value_heads = 1
|
||||||
|
|
||||||
self.kv_transfer_config.engine_id = "test_engine"
|
self.kv_transfer_config.engine_id = "test_engine"
|
||||||
self.kv_transfer_config.kv_port = 5000
|
self.kv_transfer_config.kv_port = 5000
|
||||||
|
|||||||
@@ -108,6 +108,7 @@ class AscendSFAMetadata:
|
|||||||
# chunked prefill by default if no attn_states passed
|
# chunked prefill by default if no attn_states passed
|
||||||
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
attn_state: AscendAttentionState = AscendAttentionState.ChunkedPrefill
|
||||||
sfa_cp_context: Optional[SfaCpContext] = None
|
sfa_cp_context: Optional[SfaCpContext] = None
|
||||||
|
reshape_cache_event: torch.npu.Event = None
|
||||||
|
|
||||||
|
|
||||||
M = TypeVar("M", bound=AscendSFAMetadata)
|
M = TypeVar("M", bound=AscendSFAMetadata)
|
||||||
@@ -369,6 +370,7 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
self.enable_sfa_cp = enable_dsa_cp()
|
self.enable_sfa_cp = enable_dsa_cp()
|
||||||
self.local_num_heads = self.num_heads
|
self.local_num_heads = self.num_heads
|
||||||
self.vllm_config = get_current_vllm_config()
|
self.vllm_config = get_current_vllm_config()
|
||||||
|
self.is_kv_producer = self.vllm_config.kv_transfer_config is not None and self.vllm_config.kv_transfer_config.is_kv_producer
|
||||||
if self.enable_sfa_cp:
|
if self.enable_sfa_cp:
|
||||||
self.local_num_heads = self.num_heads * self.tp_size
|
self.local_num_heads = self.num_heads * self.tp_size
|
||||||
self.layer_sharding_kwargs = []
|
self.layer_sharding_kwargs = []
|
||||||
@@ -897,11 +899,15 @@ class AscendSFAImpl(MLAAttentionImpl):
|
|||||||
k = get_tp_group().all_gather(k, 0)
|
k = get_tp_group().all_gather(k, 0)
|
||||||
|
|
||||||
if kv_cache is not None:
|
if kv_cache is not None:
|
||||||
|
if self.is_kv_producer:
|
||||||
|
attn_metadata.reshape_cache_event = torch.npu.Event()
|
||||||
torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
|
torch_npu.npu_scatter_nd_update_(kv_cache[2].view(-1, k.shape[-1]),
|
||||||
attn_metadata.slot_mapping.view(
|
attn_metadata.slot_mapping.view(
|
||||||
-1, 1),
|
-1, 1),
|
||||||
k.view(-1,
|
k.view(-1,
|
||||||
k.shape[-1])) # b, s, n, d
|
k.shape[-1])) # b, s, n, d
|
||||||
|
if self.is_kv_producer:
|
||||||
|
attn_metadata.reshape_cache_event.record()
|
||||||
|
|
||||||
weights, _ = self.weights_proj(x)
|
weights, _ = self.weights_proj(x)
|
||||||
weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
weights = torch.ops.vllm.maybe_all_gather_and_maybe_unpad(
|
||||||
|
|||||||
@@ -162,6 +162,7 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
self.kv_caches_base_addr = kv_cache_base_addr
|
self.kv_caches_base_addr = kv_cache_base_addr
|
||||||
self.total_layers = total_layers
|
self.total_layers = total_layers
|
||||||
self.use_mla = use_mla
|
self.use_mla = use_mla
|
||||||
|
self.use_sparse = len(block_len) == 3
|
||||||
self.block_len = block_len
|
self.block_len = block_len
|
||||||
self._decode_tp_size = decode_tp_size
|
self._decode_tp_size = decode_tp_size
|
||||||
self.resharding_stream = resharding_stream
|
self.resharding_stream = resharding_stream
|
||||||
@@ -195,17 +196,6 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
src_list: list[str] = []
|
src_list: list[str] = []
|
||||||
dst_list: list[str] = []
|
dst_list: list[str] = []
|
||||||
length_list: list[int] = []
|
length_list: list[int] = []
|
||||||
# not need to send kv cache
|
|
||||||
if self.tp_rank % self.num_head_replica != 0:
|
|
||||||
logger.debug(
|
|
||||||
f"Cancelling KV cache transfer for request {req_id}. Reason: TP rank excluded from head replication (TP Rank: {self.tp_rank}, Replicas: {self.num_head_replica})."
|
|
||||||
)
|
|
||||||
return (src_list, dst_list, length_list)
|
|
||||||
if self.use_mla and self.tp_rank >= self._decode_tp_size:
|
|
||||||
logger.debug(
|
|
||||||
f"Cancelling KV cache transfer for request {req_id}. Reason: MLA mode active and TP rank outside decoding group (TP Rank: {self.tp_rank}, Decode TP Size: {self._decode_tp_size})."
|
|
||||||
)
|
|
||||||
return (src_list, dst_list, length_list)
|
|
||||||
|
|
||||||
layer_idx = send_task.layer_idx
|
layer_idx = send_task.layer_idx
|
||||||
remote_block_ids = req_meta.remote_block_ids
|
remote_block_ids = req_meta.remote_block_ids
|
||||||
@@ -214,21 +204,36 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
local_block_ids = req_meta.local_block_ids
|
local_block_ids = req_meta.local_block_ids
|
||||||
|
|
||||||
if self.pd_head_ratio == 1:
|
if self.pd_head_ratio == 1:
|
||||||
layer_local_kv_base_addr = [
|
if self.use_sparse:
|
||||||
local_kv_base_addr[i]
|
layer_local_kv_base_addr = [
|
||||||
for i in [2 * layer_idx, 2 * layer_idx + 1]
|
local_kv_base_addr[i] for i in
|
||||||
]
|
[3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2]
|
||||||
layer_remote_kv_base_addr = [
|
]
|
||||||
remote_kv_base_addrs[i] # type:ignore
|
layer_remote_kv_base_addr = [
|
||||||
for i in [2 * layer_idx, 2 * layer_idx + 1]
|
remote_kv_base_addrs[i] # type:ignore
|
||||||
]
|
for i in
|
||||||
|
[3 * layer_idx, 3 * layer_idx + 1, 3 * layer_idx + 2]
|
||||||
|
]
|
||||||
|
else:
|
||||||
|
layer_local_kv_base_addr = [
|
||||||
|
local_kv_base_addr[i]
|
||||||
|
for i in [2 * layer_idx, 2 * layer_idx + 1]
|
||||||
|
]
|
||||||
|
layer_remote_kv_base_addr = [
|
||||||
|
remote_kv_base_addrs[i] # type:ignore
|
||||||
|
for i in [2 * layer_idx, 2 * layer_idx + 1]
|
||||||
|
]
|
||||||
grouped_remote_block_ids, grouped_local_block_ids = \
|
grouped_remote_block_ids, grouped_local_block_ids = \
|
||||||
group_concurrent_contiguous(remote_block_ids, local_block_ids)
|
group_concurrent_contiguous(remote_block_ids, local_block_ids)
|
||||||
|
|
||||||
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
for k, (src_layer_base_addr, dst_layer_base_addr) in enumerate(
|
||||||
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
|
zip(layer_local_kv_base_addr, layer_remote_kv_base_addr)):
|
||||||
block_len = self.block_len[
|
if self.use_mla:
|
||||||
k % 2] if self.use_mla else self.block_len[0]
|
block_len = (self.block_len[k % 2])
|
||||||
|
elif self.use_sparse:
|
||||||
|
block_len = (self.block_len[k % 3])
|
||||||
|
else:
|
||||||
|
block_len = (self.block_len[0])
|
||||||
for group_remote_block_id, group_local_block_id in zip(
|
for group_remote_block_id, group_local_block_id in zip(
|
||||||
grouped_remote_block_ids, grouped_local_block_ids):
|
grouped_remote_block_ids, grouped_local_block_ids):
|
||||||
src = src_layer_base_addr + group_local_block_id[
|
src = src_layer_base_addr + group_local_block_id[
|
||||||
@@ -931,7 +936,9 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
|
|
||||||
# TODO(tms): Find a more robust way to detect and handle MLA
|
# TODO(tms): Find a more robust way to detect and handle MLA
|
||||||
self.use_mla = first_kv_cache_tuple[0].size(
|
self.use_mla = first_kv_cache_tuple[0].size(
|
||||||
-1) != first_kv_cache_tuple[1].size(-1)
|
-1) != first_kv_cache_tuple[1].size(-1) and len(
|
||||||
|
first_kv_cache_tuple) == 2
|
||||||
|
self.use_sparse = len(first_kv_cache_tuple) == 3
|
||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
# MLA case.[num_block, block_size, 1, hidden_dim]
|
# MLA case.[num_block, block_size, 1, hidden_dim]
|
||||||
self.num_blocks = first_kv_cache.shape[0]
|
self.num_blocks = first_kv_cache.shape[0]
|
||||||
@@ -945,6 +952,21 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
logger.info(
|
logger.info(
|
||||||
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s",
|
||||||
self.num_blocks, block_shape_norm, block_shape_pe)
|
self.num_blocks, block_shape_norm, block_shape_pe)
|
||||||
|
elif self.use_sparse:
|
||||||
|
self.num_blocks = first_kv_cache.shape[0]
|
||||||
|
block_rank = 3 # [block_size, latent_dim]
|
||||||
|
block_shape_norm = first_kv_cache_tuple[0].shape[-block_rank:]
|
||||||
|
block_shape_pe = first_kv_cache_tuple[1].shape[-block_rank:]
|
||||||
|
block_shape_k = first_kv_cache_tuple[2].shape[-block_rank:]
|
||||||
|
self.block_len = [
|
||||||
|
first_kv_cache[0].element_size() * math.prod(block_shape_norm),
|
||||||
|
first_kv_cache[1].element_size() * math.prod(block_shape_pe),
|
||||||
|
first_kv_cache[2].element_size() * math.prod(block_shape_k)
|
||||||
|
]
|
||||||
|
logger.info(
|
||||||
|
"num_blocks: %s, block_shape_norm: %s, block_shape_pe: %s, block_shape_k: %s",
|
||||||
|
self.num_blocks, block_shape_norm, block_shape_pe,
|
||||||
|
block_shape_k)
|
||||||
else:
|
else:
|
||||||
# [num_block, block_size, num_head, hidden_dim]
|
# [num_block, block_size, num_head, hidden_dim]
|
||||||
self.num_blocks = first_kv_cache.shape[0]
|
self.num_blocks = first_kv_cache.shape[0]
|
||||||
@@ -955,8 +977,9 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
|
logger.info("num_blocks: %s, block_shape: %s", self.num_blocks,
|
||||||
block_shape)
|
block_shape)
|
||||||
|
|
||||||
logger.info("Registering KV_Caches. use_mla: %s, shape %s",
|
logger.info(
|
||||||
self.use_mla, first_kv_cache.shape)
|
"Registering KV_Caches. use_mla: %s, use_sparse: %s, shape %s",
|
||||||
|
self.use_mla, self.use_sparse, first_kv_cache.shape)
|
||||||
|
|
||||||
self.kv_caches = kv_caches
|
self.kv_caches = kv_caches
|
||||||
kv_caches_base_addr = []
|
kv_caches_base_addr = []
|
||||||
@@ -971,9 +994,17 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
kv_caches_base_addr.append(base_addr)
|
kv_caches_base_addr.append(base_addr)
|
||||||
ptrs.append(base_addr)
|
ptrs.append(base_addr)
|
||||||
lengths.append(region_len)
|
lengths.append(region_len)
|
||||||
|
elif self.use_sparse:
|
||||||
|
for i, cache in enumerate(cache_or_caches, 0):
|
||||||
|
base_addr = cache.data_ptr()
|
||||||
|
region_len = self.num_blocks * self.block_len[i % 3]
|
||||||
|
kv_caches_base_addr.append(base_addr)
|
||||||
|
ptrs.append(base_addr)
|
||||||
|
lengths.append(region_len)
|
||||||
else:
|
else:
|
||||||
cache_list = [cache_or_caches
|
cache_list = [
|
||||||
] if self.use_mla else cache_or_caches
|
cache_or_caches
|
||||||
|
] if self.use_mla or self.use_sparse else cache_or_caches
|
||||||
for cache in cache_list:
|
for cache in cache_list:
|
||||||
base_addr = cache.data_ptr()
|
base_addr = cache.data_ptr()
|
||||||
region_len = self.num_blocks * self.block_len[0]
|
region_len = self.num_blocks * self.block_len[0]
|
||||||
@@ -1046,56 +1077,72 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(
|
if self.vllm_config.kv_transfer_config.is_kv_producer and connector_metadata.requests.keys(
|
||||||
):
|
):
|
||||||
# enable decode prefix cache
|
# enable decode prefix cache
|
||||||
if self.use_mla:
|
if self.use_mla or self.use_sparse:
|
||||||
reshape_cache_event = attn_metadata[
|
num_kv_head = self._decode_tp_size
|
||||||
layer_name].reshape_cache_event
|
|
||||||
else:
|
else:
|
||||||
reshape_cache_event = attn_metadata.reshape_cache_event
|
num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
|
||||||
|
num_replica_groups = self.tp_size // num_kv_head if self.tp_size >= num_kv_head else 1
|
||||||
|
replica_group_idx = self.tp_rank % num_replica_groups
|
||||||
|
req_ids = sorted(list(connector_metadata.requests.keys()))
|
||||||
|
selected_req_ids = [
|
||||||
|
req_id for i, req_id in enumerate(req_ids)
|
||||||
|
if i % num_replica_groups == replica_group_idx
|
||||||
|
]
|
||||||
|
if selected_req_ids:
|
||||||
|
if self.use_mla or self.use_sparse:
|
||||||
|
reshape_cache_event = attn_metadata[
|
||||||
|
layer_name].reshape_cache_event
|
||||||
|
else:
|
||||||
|
reshape_cache_event = attn_metadata.reshape_cache_event
|
||||||
|
|
||||||
if self.pd_head_ratio != 1:
|
if self.pd_head_ratio != 1:
|
||||||
assert self.resharding_stream is not None
|
assert self.resharding_stream is not None
|
||||||
with npu_stream_switch(self.resharding_stream):
|
with npu_stream_switch(self.resharding_stream):
|
||||||
reshape_cache_event.wait()
|
reshape_cache_event.wait()
|
||||||
rearrange_block_ids = sorted({
|
rearrange_block_ids = sorted({
|
||||||
block_id
|
block_id
|
||||||
for request in connector_metadata.requests.values()
|
for req_id in selected_req_ids
|
||||||
for block_id in request.local_block_ids
|
for block_id in
|
||||||
})
|
connector_metadata.requests[req_id].local_block_ids
|
||||||
|
})
|
||||||
|
|
||||||
keys = kv_layer[0][rearrange_block_ids].clone()
|
keys = kv_layer[0][rearrange_block_ids].clone()
|
||||||
values = kv_layer[1][rearrange_block_ids].clone()
|
values = kv_layer[1][rearrange_block_ids].clone()
|
||||||
# sort kv caches for each block
|
# sort kv caches for each block
|
||||||
keys = keys.view(keys.size(0), self.pd_head_ratio, -1,
|
keys = keys.view(keys.size(0), self.pd_head_ratio, -1,
|
||||||
*keys.shape[2:]).transpose(
|
*keys.shape[2:]).transpose(
|
||||||
0, 1).reshape_as(keys)
|
0, 1).reshape_as(keys)
|
||||||
values = values.view(values.size(0), self.pd_head_ratio,
|
values = values.view(values.size(0),
|
||||||
-1, *values.shape[2:]).transpose(
|
self.pd_head_ratio, -1,
|
||||||
0, 1).reshape_as(values)
|
*values.shape[2:]).transpose(
|
||||||
# reshard kv cache
|
0, 1).reshape_as(values)
|
||||||
keys = keys.reshape(-1, *kv_layer[0].shape[2:])
|
# reshard kv cache
|
||||||
values = values.reshape(-1, *kv_layer[1].shape[2:])
|
keys = keys.reshape(-1, *kv_layer[0].shape[2:])
|
||||||
(keys, values) = kv_alltoall_and_rearrange(
|
values = values.reshape(-1, *kv_layer[1].shape[2:])
|
||||||
self.pd_head_ratio, keys, values)
|
(keys, values) = kv_alltoall_and_rearrange(
|
||||||
else:
|
self.pd_head_ratio, keys, values)
|
||||||
keys = None
|
else:
|
||||||
values = None
|
keys = None
|
||||||
rearrange_block_ids = None
|
values = None
|
||||||
|
rearrange_block_ids = None
|
||||||
|
|
||||||
assert self.kv_send_layer_thread is not None
|
assert self.kv_send_layer_thread is not None
|
||||||
assert reshape_cache_event is not None
|
assert reshape_cache_event is not None
|
||||||
send_task = SendTask(wait_event=reshape_cache_event,
|
send_task = SendTask(wait_event=reshape_cache_event,
|
||||||
k_cache=keys,
|
k_cache=keys,
|
||||||
v_cache=values,
|
v_cache=values,
|
||||||
layer_idx=self.current_layer,
|
layer_idx=self.current_layer,
|
||||||
rearrange_block_ids=rearrange_block_ids)
|
rearrange_block_ids=rearrange_block_ids)
|
||||||
for req_id, req_meta in connector_metadata.requests.items():
|
for req_id, req_meta in connector_metadata.requests.items():
|
||||||
req_meta_update = self.update_decoder_info(req_id, req_meta)
|
if req_id in selected_req_ids:
|
||||||
logger.debug(
|
req_meta_update = self.update_decoder_info(
|
||||||
f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
|
req_id, req_meta)
|
||||||
)
|
logger.debug(
|
||||||
send_task.send_request[req_id] = req_meta_update
|
f"Add request {req_id} to kv send layer thread. {req_meta_update=}"
|
||||||
|
)
|
||||||
|
send_task.send_request[req_id] = req_meta_update
|
||||||
|
|
||||||
self.kv_send_layer_thread.send_queue.put(send_task)
|
self.kv_send_layer_thread.send_queue.put(send_task)
|
||||||
self.current_layer += 1
|
self.current_layer += 1
|
||||||
|
|
||||||
def _get_remote_socket(
|
def _get_remote_socket(
|
||||||
@@ -1121,8 +1168,13 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
|
|
||||||
def update_decoder_info(self, req_id, req_meta):
|
def update_decoder_info(self, req_id, req_meta):
|
||||||
req_meta_update = copy.deepcopy(req_meta)
|
req_meta_update = copy.deepcopy(req_meta)
|
||||||
req_meta_update.remote_port = req_meta_update.remote_port + (
|
if self.use_mla or self.use_sparse:
|
||||||
self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
|
pd_tp_ratio = self.tp_size // self._decode_tp_size
|
||||||
|
req_meta_update.remote_port = req_meta_update.remote_port + (
|
||||||
|
self.tp_rank // pd_tp_ratio) % self._decode_tp_size
|
||||||
|
else:
|
||||||
|
req_meta_update.remote_port = req_meta_update.remote_port + (
|
||||||
|
self.tp_rank // self.pd_tp_ratio) % self._decode_tp_size
|
||||||
if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \
|
if req_meta_update.remote_engine_id not in self.remote_kv_caches_base_addr or \
|
||||||
req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]:
|
req_meta_update.remote_port not in self.remote_kv_caches_base_addr[req_meta_update.remote_engine_id]:
|
||||||
try:
|
try:
|
||||||
@@ -1146,14 +1198,16 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
logger.info(
|
logger.info(
|
||||||
f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
|
f"Query to port and kv base addr for request {req_id} from {req_meta_update.remote_host}:{req_meta_update.remote_port} success {agent_meta.kv_caches_base_addr=} {agent_meta.te_rpc_port=}"
|
||||||
)
|
)
|
||||||
session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
|
if self.pd_head_ratio > 1:
|
||||||
ret = self.engine.batch_transfer_sync_write(
|
# for tp inequal, pre-create link to prevent alltoall out of memory
|
||||||
session_id, [self.kv_caches_base_addr[0]],
|
session_id = f"{req_meta_update.remote_host}:{agent_meta.te_rpc_port}"
|
||||||
[agent_meta.kv_caches_base_addr[0]], [128])
|
ret = self.engine.batch_transfer_sync_write(
|
||||||
if ret < 0:
|
session_id, [self.kv_caches_base_addr[0]],
|
||||||
logger.error(
|
[agent_meta.kv_caches_base_addr[0]], [128])
|
||||||
f"Mooncake transfer failed to create link to device {session_id}"
|
if ret < 0:
|
||||||
)
|
logger.error(
|
||||||
|
f"Mooncake transfer failed to create link to device {session_id}"
|
||||||
|
)
|
||||||
req_meta_update.remote_te_rpc_port = self.remote_te_port[
|
req_meta_update.remote_te_rpc_port = self.remote_te_port[
|
||||||
req_meta_update.remote_engine_id][req_meta_update.remote_port]
|
req_meta_update.remote_engine_id][req_meta_update.remote_port]
|
||||||
req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[
|
req_meta_update.remote_kv_caches_base_addr = self.remote_kv_caches_base_addr[
|
||||||
|
|||||||
Reference in New Issue
Block a user