[v0.18.0][BugFix][P/D]Fix layerwise connector out of memory during large buffer transfer (#7752)

### What this PR does / why we need it
Fix layerwise connector out of memory during large buffer transfer.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Verified by the nightly CI test runs.

---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
zxr2333
2026-03-31 22:16:53 +08:00
committed by GitHub
parent b1cc6ef6ae
commit ef9964389f

View File

@@ -212,6 +212,7 @@ class KVCacheSendingLayerThread(threading.Thread):
num_head_replica: int,
layer_metadata: dict[str, LayerMetadata],
use_mla: bool,
use_attn_mamba_hybrid: bool,
k_buffer: torch.Tensor,
v_buffer: torch.Tensor,
enable_kv_quant: bool,
@@ -239,10 +240,17 @@ class KVCacheSendingLayerThread(threading.Thread):
self.layer_metadata = layer_metadata
self.total_layers = total_layers
self.use_mla = use_mla
self.use_attn_mamba_hybrid = use_attn_mamba_hybrid
self.resharding_stream = resharding_stream
self.current_layer = -1
self.send_queue = queue.Queue[SendTask]()
send_queue_size = 0
if self.pd_head_ratio != 1:
if self.use_attn_mamba_hybrid:
send_queue_size = len(self.kv_cache_specs)
else:
send_queue_size = 1
self.send_queue = queue.Queue[SendTask](maxsize=send_queue_size)
self.k_buffer = k_buffer
self.v_buffer = v_buffer
self.enable_kv_quant = enable_kv_quant
@@ -465,6 +473,7 @@ class KVCacheSendingLayerThread(threading.Thread):
for session_id, transfer_meta in session_meta.items():
if len(transfer_meta.src) > 0:
req_start_time = time.perf_counter()
ret = self.engine.batch_transfer_sync_write(
session_id, transfer_meta.src, transfer_meta.dst, transfer_meta.length
)
@@ -480,6 +489,16 @@ class KVCacheSendingLayerThread(threading.Thread):
req_id, req_meta, layer_group_idx
) # TODO Send a signal indicating transmission failure
else:
req_end_time = time.perf_counter()
total_transfer_size = sum(transfer_meta.length) / 1024
req_transfer_elapsed = (req_end_time - req_start_time) * 1000
logger.debug(
"Layer%d KV cache transfer task %dKB to remote_session_id [%s] took %.3f ms.",
send_task.layer_idx,
total_transfer_size,
session_id,
req_transfer_elapsed,
)
if send_task.layer_idx == (self.total_layers - 1):
for req_id in transfer_meta.req_ids:
req_meta = send_task.send_request[req_id]
@@ -996,6 +1015,7 @@ class MooncakeLayerwiseConnectorWorker:
self.side_channel_host = get_ip()
self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
self.use_mla = self.vllm_config.model_config.use_mla
self.use_attn_mamba_hybrid = False
self.request_map = dict[str, str]()
if self.use_mla:
self.total_num_kv_heads = 1
@@ -1074,7 +1094,7 @@ class MooncakeLayerwiseConnectorWorker:
buffer_list.append(self.k_buffer)
buffer_list.append(self.v_buffer)
if self.enable_kv_quant:
quant_k_cache_numel = first_kv_cache_tuple[0].numel() // 2
quant_k_cache_numel = first_kv_cache_tuple[0].numel()
self.k_quant_buffer = torch.zeros(
quant_k_cache_numel + alignment, dtype=torch.int8, device=first_kv_cache.device
)
@@ -1093,10 +1113,8 @@ class MooncakeLayerwiseConnectorWorker:
for tensor in buffer_list:
assert tensor.data_ptr() % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel())
logger.info(
f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
)
ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel() * tensor.element_size())
logger.info(f"Register memory buffer for transfer, buffer size:{tensor.numel() * tensor.element_size()}")
if ret_value != 0:
raise RuntimeError("Mooncake memory registration failed. ")
@@ -1108,7 +1126,7 @@ class MooncakeLayerwiseConnectorWorker:
for layer_name in kv_cache_group_spec.layer_names:
layer2group_ids[layer_name] = i
use_mamba, use_attn, use_attn_mamba_hybrid = False, False, False
use_mamba, use_attn = False, False
conv_total_padding_size = 0
for kv_cache_tensor in self.kv_cache_config.kv_cache_tensors:
for layer_name in kv_cache_tensor.shared_by:
@@ -1122,7 +1140,7 @@ class MooncakeLayerwiseConnectorWorker:
if isinstance(layer_kv_cache_spec, AttentionSpec):
use_attn = True
if use_mamba and use_attn:
use_attn_mamba_hybrid = True
self.use_attn_mamba_hybrid = True
break
ptrs = []
@@ -1158,7 +1176,7 @@ class MooncakeLayerwiseConnectorWorker:
single_layer_meta.block_len.append(single_kv_cache.element_size() * math.prod(block_shape))
single_layer_meta.block_size_scale.append(block_size_scale)
self.kernel_block_size_scale[layer2group_ids[layer_name]] = block_size_scale
if single_kv_cache.data_ptr() not in ptrs and not use_attn_mamba_hybrid:
if single_kv_cache.data_ptr() not in ptrs and not self.use_attn_mamba_hybrid:
ptrs.append(single_kv_cache.data_ptr())
lengths.append(
num_blocks * single_kv_cache.element_size() * math.prod(block_shape) * block_size_scale
@@ -1166,7 +1184,7 @@ class MooncakeLayerwiseConnectorWorker:
logger.info(f"layer: {layer_name}, num_blocks: {num_blocks}, block_shape: {block_shape}")
self.layer_metadata[layer_name] = single_layer_meta
if use_attn_mamba_hybrid:
if self.use_attn_mamba_hybrid:
for kv_cache_tensor in self.kv_cache_config.kv_cache_tensors:
tensor_addrs = []
for layer_name in kv_cache_tensor.shared_by:
@@ -1217,6 +1235,7 @@ class MooncakeLayerwiseConnectorWorker:
num_head_replica=self.num_head_replica,
layer_metadata=self.layer_metadata,
use_mla=self.use_mla,
use_attn_mamba_hybrid=self.use_attn_mamba_hybrid,
k_buffer=self.k_buffer,
v_buffer=self.v_buffer,
enable_kv_quant=self.enable_kv_quant,