[BugFix] Fix mooncake bug in PCP scenario (#5055)
### What this PR does / why we need it?
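mooncake_connector.py imported the wrong names from vllm.distributed
(get_prefill_context_model_parallel_rank and
get_prefill_context_model_parallel_world_size), which could cause errors
when PCP (prefill context parallel) is used. This has been corrected: the
connector now imports get_pcp_group and reads world_size and rank_in_group
from the returned group, and the unit test mocks get_pcp_group accordingly.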
### Does this PR introduce _any_ user-facing change?
NO
### How was this patch tested?
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: daishixun <dsxsteven@sina.com>
This commit is contained in:
@@ -22,6 +22,7 @@ sys.modules["mooncake.engine"] = fake_engine
|
|||||||
_mock_ascend_config = MagicMock(enable_kv_nz=False)
|
_mock_ascend_config = MagicMock(enable_kv_nz=False)
|
||||||
_mock_pp_group = MagicMock(rank_in_group=0, world_size=1)
|
_mock_pp_group = MagicMock(rank_in_group=0, world_size=1)
|
||||||
_mock_tp_group = MagicMock(rank_in_group=0, world_size=4)
|
_mock_tp_group = MagicMock(rank_in_group=0, world_size=4)
|
||||||
|
_mock_pcp_group = MagicMock(rank_in_group=0, world_size=1)
|
||||||
patch('vllm_ascend.distributed.mooncake_connector.get_pp_group',
|
patch('vllm_ascend.distributed.mooncake_connector.get_pp_group',
|
||||||
return_value=_mock_pp_group).start()
|
return_value=_mock_pp_group).start()
|
||||||
patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
|
patch('vllm_ascend.distributed.mooncake_connector.get_tp_group',
|
||||||
@@ -32,6 +33,8 @@ patch(
|
|||||||
patch(
|
patch(
|
||||||
'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank',
|
'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank',
|
||||||
return_value=0).start()
|
return_value=0).start()
|
||||||
|
patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group',
|
||||||
|
return_value=_mock_pcp_group).start()
|
||||||
|
|
||||||
from vllm_ascend.distributed.mooncake_connector import ( # noqa: E402
|
from vllm_ascend.distributed.mooncake_connector import ( # noqa: E402
|
||||||
KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker,
|
KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker,
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ import zmq
|
|||||||
from mooncake.engine import TransferEngine # type: ignore
|
from mooncake.engine import TransferEngine # type: ignore
|
||||||
from vllm import envs
|
from vllm import envs
|
||||||
from vllm.config import VllmConfig
|
from vllm.config import VllmConfig
|
||||||
|
from vllm.distributed import get_pcp_group
|
||||||
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
from vllm.distributed.kv_transfer.kv_connector.v1.base import (
|
||||||
KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata,
|
KVConnectorBase_V1, KVConnectorHandshakeMetadata, KVConnectorMetadata,
|
||||||
KVConnectorRole)
|
KVConnectorRole)
|
||||||
@@ -43,13 +44,6 @@ from vllm_ascend.distributed.mooncake_transfer_engine import global_te
|
|||||||
from vllm_ascend.distributed.utils import get_transfer_timeout_value
|
from vllm_ascend.distributed.utils import get_transfer_timeout_value
|
||||||
from vllm_ascend.utils import prefill_context_parallel_enable
|
from vllm_ascend.utils import prefill_context_parallel_enable
|
||||||
|
|
||||||
# isort: off
|
|
||||||
if prefill_context_parallel_enable():
|
|
||||||
from vllm.distributed import (get_prefill_context_model_parallel_rank,
|
|
||||||
get_prefill_context_model_parallel_world_size
|
|
||||||
)
|
|
||||||
# isort: on
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from vllm.attention.backends.abstract import AttentionMetadata
|
from vllm.attention.backends.abstract import AttentionMetadata
|
||||||
from vllm.forward_context import ForwardContext
|
from vllm.forward_context import ForwardContext
|
||||||
@@ -1003,13 +997,12 @@ class MooncakeConnectorWorker:
|
|||||||
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
|
self.pp_size = vllm_config.parallel_config.pipeline_parallel_size
|
||||||
self.kv_caches: dict[str, torch.Tensor] = {}
|
self.kv_caches: dict[str, torch.Tensor] = {}
|
||||||
self.side_channel_host = get_ip()
|
self.side_channel_host = get_ip()
|
||||||
self.pcp_size = get_prefill_context_model_parallel_world_size(
|
self.pcp_size = get_pcp_group().world_size
|
||||||
) if prefill_context_parallel_enable() else 1
|
|
||||||
# Assert that pp_size and pcp_size cannot both be greater than 1
|
# Assert that pp_size and pcp_size cannot both be greater than 1
|
||||||
assert not (self.pp_size > 1 and self.pcp_size
|
assert not (self.pp_size > 1 and self.pcp_size
|
||||||
> 1), "pp and pcp cannot open in same time"
|
> 1), "pp and pcp cannot open in same time"
|
||||||
self.pcp_rank = get_prefill_context_model_parallel_rank(
|
self.pcp_rank = get_pcp_group(
|
||||||
) if self.pcp_size > 1 else 0
|
).rank_in_group if self.pcp_size > 1 else 0
|
||||||
self.dcp_size = get_decode_context_model_parallel_world_size()
|
self.dcp_size = get_decode_context_model_parallel_world_size()
|
||||||
self.dcp_rank = get_decode_context_model_parallel_rank(
|
self.dcp_rank = get_decode_context_model_parallel_rank(
|
||||||
) if self.dcp_size > 1 else 0
|
) if self.dcp_size > 1 else 0
|
||||||
|
|||||||
Reference in New Issue
Block a user