[Bugfix][P/D] fix layerwise connector for decoder tp size > num kv heads (#5846)
### What this PR does / why we need it?
Fix the layerwise connector for the case where the decoder TP size is greater than the number of KV heads. In this case, the prefiller should push the KV cache to all decoder NPUs.
- vLLM version: v0.13.0
- vLLM main: 2f4e6548ef
Signed-off-by: liziyu <liziyu16@huawei.com>
@@ -1078,10 +1078,14 @@ class MooncakeLayerwiseConnectorWorker:
     ):
         # enable decode prefix cache
         if self.use_mla or self.use_sparse:
-            num_kv_head = self._decode_tp_size
+            num_need_send = self._decode_tp_size
         else:
             num_kv_head = self.vllm_config.model_config.hf_config.num_key_value_heads
-        num_replica_groups = self.tp_size // num_kv_head if self.tp_size >= num_kv_head else 1
+            if self.tp_size <= num_kv_head:
+                num_need_send = self.tp_size
+            else:
+                num_need_send = self._decode_tp_size if self._decode_tp_size >= num_kv_head else num_kv_head
+        num_replica_groups = self.tp_size // num_need_send if self.tp_size >= num_need_send else 1
         replica_group_idx = self.tp_rank % num_replica_groups
         req_ids = sorted(list(connector_metadata.requests.keys()))
         selected_req_ids = [
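The replica-group arithmetic in this hunk can be tried in isolation. The sketch below is a minimal standalone restatement of it, not the connector's actual code: the function names `num_need_send` and `replica_groups` and the plain-integer parameters are hypothetical stand-ins for `self.tp_size`, `self._decode_tp_size`, and `num_key_value_heads`. It also includes the pre-fix formula for comparison.

```python
def num_need_send(prefill_tp: int, decode_tp: int, num_kv_heads: int,
                  mla_or_sparse: bool = False) -> int:
    """Number of distinct prefill ranks that must send KV cache (hypothetical helper)."""
    if mla_or_sparse:
        # MLA / sparse attention: one sender per decoder rank.
        return decode_tp
    if prefill_tp <= num_kv_heads:
        # Each prefill rank holds distinct KV heads, so all of them send.
        return prefill_tp
    # Prefill TP exceeds the KV head count: cover every decoder rank when the
    # decoder TP is at least the KV head count, otherwise one sender per KV head.
    return decode_tp if decode_tp >= num_kv_heads else num_kv_heads


def replica_groups(prefill_tp: int, need_send: int) -> int:
    """Number of replica groups among the prefill ranks (same formula as the hunk)."""
    return prefill_tp // need_send if prefill_tp >= need_send else 1


def old_replica_groups(prefill_tp: int, num_kv_heads: int) -> int:
    """Pre-fix formula for the non-MLA path, which ignored the decoder TP size."""
    return prefill_tp // num_kv_heads if prefill_tp >= num_kv_heads else 1


if __name__ == "__main__":
    # Example: prefill TP 16, decoder TP 16, 8 KV heads (decoder TP > KV heads).
    n = num_need_send(prefill_tp=16, decode_tp=16, num_kv_heads=8)
    print("senders:", n)
    print("replica groups (new):", replica_groups(16, n))
    print("replica groups (old):", old_replica_groups(16, 8))
```

With a prefill TP of 16, a decoder TP of 16, and 8 KV heads, the old formula yields 2 replica groups while the new one yields 1. How the surrounding method uses `replica_group_idx` to select requests is not visible in this hunk, so the sketch stops at the group count.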