[v0.18.0][BugFix][P/D]Fix layerwise connector out of memory during large buffer transfer (#7752)
### What this PR does / why we need Fix layerwise connector out of memory during large buffer transfer. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? By nightly. --------- Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com> Signed-off-by: Mengqing Cao <cmq0113@163.com> Co-authored-by: Mengqing Cao <cmq0113@163.com>
This commit is contained in:
@@ -212,6 +212,7 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
num_head_replica: int,
|
num_head_replica: int,
|
||||||
layer_metadata: dict[str, LayerMetadata],
|
layer_metadata: dict[str, LayerMetadata],
|
||||||
use_mla: bool,
|
use_mla: bool,
|
||||||
|
use_attn_mamba_hybrid: bool,
|
||||||
k_buffer: torch.Tensor,
|
k_buffer: torch.Tensor,
|
||||||
v_buffer: torch.Tensor,
|
v_buffer: torch.Tensor,
|
||||||
enable_kv_quant: bool,
|
enable_kv_quant: bool,
|
||||||
@@ -239,10 +240,17 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
self.layer_metadata = layer_metadata
|
self.layer_metadata = layer_metadata
|
||||||
self.total_layers = total_layers
|
self.total_layers = total_layers
|
||||||
self.use_mla = use_mla
|
self.use_mla = use_mla
|
||||||
|
self.use_attn_mamba_hybrid = use_attn_mamba_hybrid
|
||||||
self.resharding_stream = resharding_stream
|
self.resharding_stream = resharding_stream
|
||||||
self.current_layer = -1
|
self.current_layer = -1
|
||||||
|
|
||||||
self.send_queue = queue.Queue[SendTask]()
|
send_queue_size = 0
|
||||||
|
if self.pd_head_ratio != 1:
|
||||||
|
if self.use_attn_mamba_hybrid:
|
||||||
|
send_queue_size = len(self.kv_cache_specs)
|
||||||
|
else:
|
||||||
|
send_queue_size = 1
|
||||||
|
self.send_queue = queue.Queue[SendTask](maxsize=send_queue_size)
|
||||||
self.k_buffer = k_buffer
|
self.k_buffer = k_buffer
|
||||||
self.v_buffer = v_buffer
|
self.v_buffer = v_buffer
|
||||||
self.enable_kv_quant = enable_kv_quant
|
self.enable_kv_quant = enable_kv_quant
|
||||||
@@ -465,6 +473,7 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
|
|
||||||
for session_id, transfer_meta in session_meta.items():
|
for session_id, transfer_meta in session_meta.items():
|
||||||
if len(transfer_meta.src) > 0:
|
if len(transfer_meta.src) > 0:
|
||||||
|
req_start_time = time.perf_counter()
|
||||||
ret = self.engine.batch_transfer_sync_write(
|
ret = self.engine.batch_transfer_sync_write(
|
||||||
session_id, transfer_meta.src, transfer_meta.dst, transfer_meta.length
|
session_id, transfer_meta.src, transfer_meta.dst, transfer_meta.length
|
||||||
)
|
)
|
||||||
@@ -480,6 +489,16 @@ class KVCacheSendingLayerThread(threading.Thread):
|
|||||||
req_id, req_meta, layer_group_idx
|
req_id, req_meta, layer_group_idx
|
||||||
) # TODO Send a signal indicating transmission failure
|
) # TODO Send a signal indicating transmission failure
|
||||||
else:
|
else:
|
||||||
|
req_end_time = time.perf_counter()
|
||||||
|
total_transfer_size = sum(transfer_meta.length) / 1024
|
||||||
|
req_transfer_elapsed = (req_end_time - req_start_time) * 1000
|
||||||
|
logger.debug(
|
||||||
|
"Layer%d KV cache transfer task %dKB to remote_session_id [%s] took %.3f ms.",
|
||||||
|
send_task.layer_idx,
|
||||||
|
total_transfer_size,
|
||||||
|
session_id,
|
||||||
|
req_transfer_elapsed,
|
||||||
|
)
|
||||||
if send_task.layer_idx == (self.total_layers - 1):
|
if send_task.layer_idx == (self.total_layers - 1):
|
||||||
for req_id in transfer_meta.req_ids:
|
for req_id in transfer_meta.req_ids:
|
||||||
req_meta = send_task.send_request[req_id]
|
req_meta = send_task.send_request[req_id]
|
||||||
@@ -996,6 +1015,7 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
self.side_channel_host = get_ip()
|
self.side_channel_host = get_ip()
|
||||||
self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
|
self.total_layers = vllm_config.model_config.get_num_layers(vllm_config.parallel_config)
|
||||||
self.use_mla = self.vllm_config.model_config.use_mla
|
self.use_mla = self.vllm_config.model_config.use_mla
|
||||||
|
self.use_attn_mamba_hybrid = False
|
||||||
self.request_map = dict[str, str]()
|
self.request_map = dict[str, str]()
|
||||||
if self.use_mla:
|
if self.use_mla:
|
||||||
self.total_num_kv_heads = 1
|
self.total_num_kv_heads = 1
|
||||||
@@ -1074,7 +1094,7 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
buffer_list.append(self.k_buffer)
|
buffer_list.append(self.k_buffer)
|
||||||
buffer_list.append(self.v_buffer)
|
buffer_list.append(self.v_buffer)
|
||||||
if self.enable_kv_quant:
|
if self.enable_kv_quant:
|
||||||
quant_k_cache_numel = first_kv_cache_tuple[0].numel() // 2
|
quant_k_cache_numel = first_kv_cache_tuple[0].numel()
|
||||||
self.k_quant_buffer = torch.zeros(
|
self.k_quant_buffer = torch.zeros(
|
||||||
quant_k_cache_numel + alignment, dtype=torch.int8, device=first_kv_cache.device
|
quant_k_cache_numel + alignment, dtype=torch.int8, device=first_kv_cache.device
|
||||||
)
|
)
|
||||||
@@ -1093,10 +1113,8 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
|
|
||||||
for tensor in buffer_list:
|
for tensor in buffer_list:
|
||||||
assert tensor.data_ptr() % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
|
assert tensor.data_ptr() % alignment == 0, "The address of the registered kv cache should be aligned to 2M"
|
||||||
ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel())
|
ret_value = self.engine.register_memory(tensor.data_ptr(), tensor.numel() * tensor.element_size())
|
||||||
logger.info(
|
logger.info(f"Register memory buffer for transfer, buffer size:{tensor.numel() * tensor.element_size()}")
|
||||||
f"Register memory for prefill when pd head ratio > 1 {tensor.data_ptr()} {tensor.numel()} {ret_value=}"
|
|
||||||
)
|
|
||||||
if ret_value != 0:
|
if ret_value != 0:
|
||||||
raise RuntimeError("Mooncake memory registration failed. ")
|
raise RuntimeError("Mooncake memory registration failed. ")
|
||||||
|
|
||||||
@@ -1108,7 +1126,7 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
for layer_name in kv_cache_group_spec.layer_names:
|
for layer_name in kv_cache_group_spec.layer_names:
|
||||||
layer2group_ids[layer_name] = i
|
layer2group_ids[layer_name] = i
|
||||||
|
|
||||||
use_mamba, use_attn, use_attn_mamba_hybrid = False, False, False
|
use_mamba, use_attn = False, False
|
||||||
conv_total_padding_size = 0
|
conv_total_padding_size = 0
|
||||||
for kv_cache_tensor in self.kv_cache_config.kv_cache_tensors:
|
for kv_cache_tensor in self.kv_cache_config.kv_cache_tensors:
|
||||||
for layer_name in kv_cache_tensor.shared_by:
|
for layer_name in kv_cache_tensor.shared_by:
|
||||||
@@ -1122,7 +1140,7 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
if isinstance(layer_kv_cache_spec, AttentionSpec):
|
if isinstance(layer_kv_cache_spec, AttentionSpec):
|
||||||
use_attn = True
|
use_attn = True
|
||||||
if use_mamba and use_attn:
|
if use_mamba and use_attn:
|
||||||
use_attn_mamba_hybrid = True
|
self.use_attn_mamba_hybrid = True
|
||||||
break
|
break
|
||||||
|
|
||||||
ptrs = []
|
ptrs = []
|
||||||
@@ -1158,7 +1176,7 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
single_layer_meta.block_len.append(single_kv_cache.element_size() * math.prod(block_shape))
|
single_layer_meta.block_len.append(single_kv_cache.element_size() * math.prod(block_shape))
|
||||||
single_layer_meta.block_size_scale.append(block_size_scale)
|
single_layer_meta.block_size_scale.append(block_size_scale)
|
||||||
self.kernel_block_size_scale[layer2group_ids[layer_name]] = block_size_scale
|
self.kernel_block_size_scale[layer2group_ids[layer_name]] = block_size_scale
|
||||||
if single_kv_cache.data_ptr() not in ptrs and not use_attn_mamba_hybrid:
|
if single_kv_cache.data_ptr() not in ptrs and not self.use_attn_mamba_hybrid:
|
||||||
ptrs.append(single_kv_cache.data_ptr())
|
ptrs.append(single_kv_cache.data_ptr())
|
||||||
lengths.append(
|
lengths.append(
|
||||||
num_blocks * single_kv_cache.element_size() * math.prod(block_shape) * block_size_scale
|
num_blocks * single_kv_cache.element_size() * math.prod(block_shape) * block_size_scale
|
||||||
@@ -1166,7 +1184,7 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
logger.info(f"layer: {layer_name}, num_blocks: {num_blocks}, block_shape: {block_shape}")
|
logger.info(f"layer: {layer_name}, num_blocks: {num_blocks}, block_shape: {block_shape}")
|
||||||
self.layer_metadata[layer_name] = single_layer_meta
|
self.layer_metadata[layer_name] = single_layer_meta
|
||||||
|
|
||||||
if use_attn_mamba_hybrid:
|
if self.use_attn_mamba_hybrid:
|
||||||
for kv_cache_tensor in self.kv_cache_config.kv_cache_tensors:
|
for kv_cache_tensor in self.kv_cache_config.kv_cache_tensors:
|
||||||
tensor_addrs = []
|
tensor_addrs = []
|
||||||
for layer_name in kv_cache_tensor.shared_by:
|
for layer_name in kv_cache_tensor.shared_by:
|
||||||
@@ -1217,6 +1235,7 @@ class MooncakeLayerwiseConnectorWorker:
|
|||||||
num_head_replica=self.num_head_replica,
|
num_head_replica=self.num_head_replica,
|
||||||
layer_metadata=self.layer_metadata,
|
layer_metadata=self.layer_metadata,
|
||||||
use_mla=self.use_mla,
|
use_mla=self.use_mla,
|
||||||
|
use_attn_mamba_hybrid=self.use_attn_mamba_hybrid,
|
||||||
k_buffer=self.k_buffer,
|
k_buffer=self.k_buffer,
|
||||||
v_buffer=self.v_buffer,
|
v_buffer=self.v_buffer,
|
||||||
enable_kv_quant=self.enable_kv_quant,
|
enable_kv_quant=self.enable_kv_quant,
|
||||||
|
|||||||
Reference in New Issue
Block a user