[Bugfix] Fix accuracy issues when the TP size is larger than the number of KV cache heads (#3366)
### What this PR does / why we need it?

Resolves an issue that occurs with unequal TP (Tensor Parallelism) when the TP size is larger than the number of the model's attention KV cache heads: the KV cache then contains duplicated heads, which led to transmission errors in the original code.

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

By CI.

- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

---------

Signed-off-by: nwpu-zxr <zhouxuerong2@huawei.com>
Signed-off-by: wangxiaoteng <wangxiaoteng@huawei.com>
Co-authored-by: nwpu-zxr <zhouxuerong2@huawei.com>
This commit is contained in:
@@ -277,7 +277,7 @@ class SendingLayerThread(threading.Thread):
|
||||
self.send_queue = queue.Queue[tuple[DecodeMooncakeAgentMetadata, str,
|
||||
list[int], int, torch.Tensor,
|
||||
torch.Tensor]]()
|
||||
self.completion_event: threading.Event
|
||||
self.completion_event: Optional[threading.Event] = None
|
||||
self.completion_event_count: int
|
||||
self.task_tracker = task_tracker
|
||||
self.total_layers = total_layers
|
||||
@@ -287,6 +287,8 @@ class SendingLayerThread(threading.Thread):
|
||||
self.engine = engine
|
||||
self.tp_rank = tp_rank
|
||||
self.pd_tp_ratio = get_ascend_config().pd_tp_ratio
|
||||
self.num_head_replica = get_ascend_config().num_head_replica
|
||||
self.pd_head_ratio = get_ascend_config().pd_head_ratio
|
||||
vllm_config = get_current_vllm_config()
|
||||
max_model_len = vllm_config.scheduler_config.max_model_len
|
||||
first_kv_cache = first_kv_cache[:max_model_len]
|
||||
@@ -358,7 +360,9 @@ class SendingLayerThread(threading.Thread):
|
||||
remote_kv_base_addrs = req_meta.kv_caches_base_addr
|
||||
|
||||
remote_block_ids = req_meta.block_ids
|
||||
if self.pd_tp_ratio == 1:
|
||||
if self.num_head_replica >= 1 and self.tp_rank % self.num_head_replica != 0:
|
||||
pass
|
||||
elif self.pd_head_ratio == 1:
|
||||
layer_local_kv_base_addr = [
|
||||
self.local_kv_base_addr[i]
|
||||
for i in [2 * layer_index, 2 * layer_index + 1]
|
||||
@@ -420,7 +424,7 @@ class SendingLayerThread(threading.Thread):
|
||||
src_layer_addr = src_layer_base_addr
|
||||
for group_remote_block_id in grouped_remote_block_ids:
|
||||
block_len = self.block_len[0]
|
||||
remote_block_len = self.block_len[0] * self.pd_tp_ratio
|
||||
remote_block_len = self.block_len[0] * self.pd_head_ratio
|
||||
src_list.append(src_layer_addr)
|
||||
|
||||
if src_layer_addr + len(
|
||||
@@ -436,23 +440,21 @@ class SendingLayerThread(threading.Thread):
|
||||
dst_list.append(dst_layer_base_addr +
|
||||
group_remote_block_id[0] *
|
||||
remote_block_len + length *
|
||||
(self.tp_rank % self.pd_tp_ratio))
|
||||
((self.tp_rank // self.num_head_replica) %
|
||||
self.pd_head_ratio))
|
||||
src_layer_addr += length
|
||||
torch.npu.synchronize()
|
||||
ret = self.engine.batch_transfer_sync_write(
|
||||
session_id, src_list, dst_list, length_list)
|
||||
self.completion_event_count -= 1
|
||||
|
||||
if self.completion_event_count == 0 and self.completion_event is not None:
|
||||
print(
|
||||
f"[_transfer_kv_cache] {self.completion_event_count} self.event.set()"
|
||||
)
|
||||
self.completion_event.set()
|
||||
|
||||
if ret < 0:
|
||||
logger.error("Mooncake transfer failed for request %s",
|
||||
req_meta.req_id)
|
||||
raise RuntimeError(f"Mooncake transfer failed, ret: {ret}")
|
||||
if self.completion_event is not None:
|
||||
self.completion_event_count -= 1
|
||||
if self.completion_event_count == 0:
|
||||
self.completion_event.set()
|
||||
self.completion_event = None
|
||||
|
||||
def add_event(self, event: threading.Event, count: int) -> None:
|
||||
self.completion_event = event
|
||||
@@ -924,6 +926,8 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
self.kv_caches_base_addr: list[int] = []
|
||||
|
||||
self.pd_tp_ratio = get_ascend_config().pd_tp_ratio
|
||||
self.pd_head_ratio = get_ascend_config().pd_head_ratio
|
||||
|
||||
self.first_kv_cache = None
|
||||
|
||||
def _get_prefill_decode_size(self, vllm_config: VllmConfig):
|
||||
@@ -1104,7 +1108,7 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
path = make_zmq_path(
|
||||
"tcp", meta.remote_host, meta.remote_port +
|
||||
self.tp_rank * self.pd_tp_ratio + offset)
|
||||
logger.debug(
|
||||
logger.info(
|
||||
f"Notify the prefiller: {path} that request: {req_id} from decoder is ready."
|
||||
)
|
||||
msg_encoder = msgspec.msgpack.Encoder()
|
||||
@@ -1142,7 +1146,7 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
**kwargs) -> None:
|
||||
"""MooncakeLayerwiseConnector does not save explicitly."""
|
||||
if self.kv_role == 'kv_producer':
|
||||
if self.pd_tp_ratio != 1:
|
||||
if self.pd_head_ratio != 1:
|
||||
if self.current_layer != 0:
|
||||
self.completion_event.wait()
|
||||
self.completion_event = threading.Event()
|
||||
@@ -1153,8 +1157,9 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
|
||||
def sort_kv_cache(input_kv: list[list[int]]):
|
||||
return torch.cat([
|
||||
torch.chunk(tensor, self.pd_tp_ratio, dim=0)[x]
|
||||
for x in range(self.pd_tp_ratio) for tensor in input_kv
|
||||
torch.chunk(tensor, self.pd_head_ratio, dim=0)[x]
|
||||
for x in range(self.pd_head_ratio)
|
||||
for tensor in input_kv
|
||||
])
|
||||
|
||||
total_block_ids = [
|
||||
@@ -1176,7 +1181,7 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
keys = sort_kv_cache(keys) # [req1_key, req2_key]
|
||||
values = sort_kv_cache(values)
|
||||
(keys,
|
||||
values) = kv_alltoall_and_rearrange(self.pd_tp_ratio, keys,
|
||||
values) = kv_alltoall_and_rearrange(self.pd_head_ratio, keys,
|
||||
values)
|
||||
key_start_id = 0
|
||||
value_start_id = 0
|
||||
@@ -1185,7 +1190,7 @@ class MooncakeLayerwiseConnectorWorker:
|
||||
value = None
|
||||
for req_id, request in connector_metadata.requests.items():
|
||||
logger.info(f"Add request {req_id} to kv send layer thread. ")
|
||||
if self.pd_tp_ratio != 1:
|
||||
if self.pd_head_ratio != 1:
|
||||
key_block_num = len(
|
||||
request.local_block_ids) * key_block_size
|
||||
value_block_num = len(
|
||||
|
||||
Reference in New Issue
Block a user