[v0.18.0][BugFix][P/D] Fix layerwise connector out-of-memory during large buffer transfer (#7752)
### What this PR does / why we need it?
Fixes an out-of-memory failure in the layerwise connector during large buffer transfers.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Covered by the nightly CI run.

---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: Mengqing Cao <cmq0113@163.com>
Co-authored-by: Mengqing Cao <cmq0113@163.com>
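For reviewers, a minimal sketch of the size arithmetic behind the main change. The buffer shape is made up, and the premise that Mooncake's `register_memory` expects a length in bytes is inferred from the diff below rather than taken from Mooncake's documentation:

```python
import torch

# Illustrative only: the old code passed tensor.numel() (an ELEMENT count)
# as the registration length, while the fixed code passes
# tensor.numel() * tensor.element_size() (a BYTE count).
buf = torch.zeros(4096, dtype=torch.bfloat16)  # hypothetical KV buffer

old_len = buf.numel()                        # 4096 -- elements, not bytes
new_len = buf.numel() * buf.element_size()   # 8192 -- actual byte size

# For any dtype wider than one byte, the old length covers only part of
# the buffer, so a large layerwise transfer can run past the registered
# region.
assert new_len == 2 * old_len
```

The same diff also sizes the int8 quantization buffer from the full `numel()` instead of `numel() // 2`, and bounds the sending thread's `send_queue` so producers block on a full queue instead of accumulating transfer tasks without limit.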
@@ -212,6 +212,7 @@ class KVCacheSendingLayerThread(threading.Thread):
         num_head_replica: int,
         layer_metadata: dict[str, LayerMetadata],
         use_mla: bool,
+        use_attn_mamba_hybrid: bool,
         k_buffer: torch.Tensor,
         v_buffer: torch.Tensor,
         enable_kv_quant: bool,
@@ -239,10 +240,17 @@ class KVCacheSendingLayerThread(threading.Thread):
         self.layer_metadata = layer_metadata
         self.total_layers = total_layers
         self.use_mla = use_mla
+        self.use_attn_mamba_hybrid = use_attn_mamba_hybrid
         self.resharding_stream = resharding_stream
         self.current_layer = -1

-        self.send_queue = queue.Queue[SendTask]()
+        send_queue_size = 0
+        if self.pd_head_ratio != 1:
+            if self.use_attn_mamba_hybrid:
+                send_queue_size = len(self.kv_cache_specs)
+            else:
+                send_queue_size = 1
+        self.send_queue = queue.Queue[SendTask](maxsize=send_queue_size)
         self.k_buffer = k_buffer
         self.v_buffer = v_buffer
         self.enable_kv_quant = enable_kv_quant
@@ -465,6 +473,7 @@ class KVCacheSendingLayerThread(threading.Thread):

             for session_id, transfer_meta in session_meta.items():
                 if len(transfer_meta.src) > 0:
+                    req_start_time = time.perf_counter()
                     ret = self.engine.batch_transfer_sync_write(
                         session_id, transfer_meta.src, transfer_meta.dst, transfer_meta.length
                     )
@@ -480,6 +489,16 @@ class KVCacheSendingLayerThread(threading.Thread):
                             req_id, req_meta, layer_group_idx
                         )  # TODO Send a signal indicating transmission failure
                     else:
+                        req_end_time = time.perf_counter()
+                        total_transfer_size = sum(transfer_meta.length) / 1024
+                        req_transfer_elapsed = (req_end_time - req_start_time) * 1000
+                        logger.debug(
+                            "Layer%d KV cache transfer task %dKB to remote_session_id [%s] took %.3f ms.",
+                            send_task.layer_idx,
+                            total_transfer_size,
+                            session_id,
+                            req_transfer_elapsed,
+                        )
                         if send_task.layer_idx == (self.total_layers - 1):
                             for req_id in transfer_meta.req_ids:
                                 req_meta = send_task.send_request[req_id]
@@ -996,6 +1015,7 @@ class MooncakeLayerwiseConnectorWorker:
        self.side_channel_host = get_ip()
        self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
        self.use_mla = self.vllm_config.model_config.use_mla
+       self.use_attn_mamba_hybrid = False
        self.request_map = dict[str, str]()
        if self.use_mla:
            self.total_num_kv_heads = 1
@@ -1074,7 +1094,7 @@ class MooncakeLayerwiseConnectorWorker:
            buffer_list.append(self.k_buffer)
            buffer_list.append(self.v_buffer)
            if self.enable_kv_quant:
-                quant_k_cache_numel = first_kv_cache_tuple[0].numel() // 2
+                quant_k_cache_numel = first_kv_cache_tuple[0].numel()
                self.k_quant_buffer = torch.zeros(
                    quant_k_cache_numel + alignment, dtype=torch.int8, device=first_kv_cache.device
                )
@@ -1093,10 +1113,8 @@ class MooncakeLayerwiseConnectorWorker:

            for tensor in buffer_list:
                assert tensor.data_ptr() % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
-                ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel())
-                logger.info(
-                    f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
-                )
+                ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel() * tensor.element_size())
+                logger.info(f"Register memory buffer for transfer, buffer size:{tensor.numel() * tensor.element_size()}")
                if ret_value != 0:
                    raise RuntimeError("Mooncake memory registration failed. ")

@@ -1108,7 +1126,7 @@ class MooncakeLayerwiseConnectorWorker:
            for layer_name in kv_cache_group_spec.layer_names:
                layer2group_ids[layer_name] = i

-        use_mamba, use_attn, use_attn_mamba_hybrid = False, False, False
+        use_mamba, use_attn = False, False
        conv_total_padding_size = 0
        for kv_cache_tensor in self.kv_cache_config.kv_cache_tensors:
            for layer_name in kv_cache_tensor.shared_by:
@@ -1122,7 +1140,7 @@ class MooncakeLayerwiseConnectorWorker:
                if isinstance(layer_kv_cache_spec, AttentionSpec):
                    use_attn = True
                if use_mamba and use_attn:
-                    use_attn_mamba_hybrid = True
+                    self.use_attn_mamba_hybrid = True
                    break

        ptrs = []
@@ -1158,7 +1176,7 @@ class MooncakeLayerwiseConnectorWorker:
                single_layer_meta.block_len.append(single_kv_cache.element_size() * math.prod(block_shape))
                single_layer_meta.block_size_scale.append(block_size_scale)
                self.kernel_block_size_scale[layer2group_ids[layer_name]] = block_size_scale
-                if single_kv_cache.data_ptr() not in ptrs and not use_attn_mamba_hybrid:
+                if single_kv_cache.data_ptr() not in ptrs and not self.use_attn_mamba_hybrid:
                    ptrs.append(single_kv_cache.data_ptr())
                    lengths.append(
                        num_blocks * single_kv_cache.element_size() * math.prod(block_shape) * block_size_scale
@@ -1166,7 +1184,7 @@ class MooncakeLayerwiseConnectorWorker:
            logger.info(f"layer: {layer_name}, num_blocks: {num_blocks}, block_shape: {block_shape}")
            self.layer_metadata[layer_name] = single_layer_meta

-        if use_attn_mamba_hybrid:
+        if self.use_attn_mamba_hybrid:
            for kv_cache_tensor in self.kv_cache_config.kv_cache_tensors:
                tensor_addrs = []
                for layer_name in kv_cache_tensor.shared_by:
@@ -1217,6 +1235,7 @@ class MooncakeLayerwiseConnectorWorker:
            num_head_replica=self.num_head_replica,
            layer_metadata=self.layer_metadata,
            use_mla=self.use_mla,
+           use_attn_mamba_hybrid=self.use_attn_mamba_hybrid,
            k_buffer=self.k_buffer,
            v_buffer=self.v_buffer,
            enable_kv_quant=self.enable_kv_quant,