diff --git a/tests/ut/kv_connector/test_mooncake_connector.py b/tests/ut/kv_connector/test_mooncake_connector.py index 92305170..27f2e1e3 100644 --- a/tests/ut/kv_connector/test_mooncake_connector.py +++ b/tests/ut/kv_connector/test_mooncake_connector.py @@ -19,6 +19,20 @@ fake_engine = types.ModuleType("mooncake.engine") fake_engine.TransferEngine = MagicMock() # type: ignore[attr-defined] sys.modules["mooncake.engine"] = fake_engine +_mock_ascend_config = MagicMock(enable_kv_nz=False) +_mock_pp_group = MagicMock(rank_in_group=0, world_size=1) +_mock_tp_group = MagicMock(rank_in_group=0, world_size=4) +patch('vllm_ascend.distributed.mooncake_connector.get_pp_group', + return_value=_mock_pp_group).start() +patch('vllm_ascend.distributed.mooncake_connector.get_tp_group', + return_value=_mock_tp_group).start() +patch( + 'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_world_size', + return_value=4).start() +patch( + 'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank', + return_value=0).start() + from vllm_ascend.distributed.mooncake_connector import ( # noqa: E402 KVCacheRecvingThread, KVCacheSendingThread, KVCacheTaskTracker, KVConnectorRole, MooncakeAgentMetadata, MooncakeConnector, @@ -88,6 +102,7 @@ class TestKVCacheSendingThreadInit(unittest.TestCase): 'side_channel_host': 'localhost', 'side_channel_port': 5555, 'metadata': MagicMock(), + 'vllm_config': MockVllmConfig(), 'ready_event': threading.Event(), 'kv_caches': kv_caches, 'pcp_rank': 0 @@ -130,6 +145,7 @@ class TestGetAndClearFinishedRequests(unittest.TestCase): 'prefill_tp_size': 4, 'local_engine_id': 'engine_1', 'side_channel_host': 'localhost', + 'vllm_config': MockVllmConfig(), 'side_channel_port': 5555, 'metadata': { "test": "metadata" @@ -159,27 +175,32 @@ class TestKVCacheSendingThread(unittest.TestCase): kv_caches_base_addr=[12345678], num_blocks=2, ) + vllm_config = MockVllmConfig() host = "127.0.0.1" with socket.socket(socket.AF_INET, 
socket.SOCK_STREAM) as s: s.bind(('', 0)) - free_port = s.getsockname()[1] + base_port = s.getsockname()[1] thread = KVCacheSendingThread(tp_rank=0, prefill_tp_size=1, local_engine_id="engine1", side_channel_host=host, - side_channel_port=free_port, + side_channel_port=base_port, metadata=metadata, + vllm_config=vllm_config, ready_event=ready_event, kv_caches={}, pcp_rank=0) thread.start() + actual_port = base_port + (thread.pp_rank * thread.tp_size + + thread.tp_rank + + thread.pcp_rank * thread.prefill_tp_size) self.assertTrue(ready_event.wait(timeout=3), "Server thread startup timeout") context = zmq.Context() # type: ignore sock = context.socket(zmq.DEALER) # type: ignore - sock.connect(f"tcp://{host}:{free_port}") + sock.connect(f"tcp://{host}:{actual_port}") encoder = msgspec.msgpack.Encoder() decoder = msgspec.msgpack.Decoder(type=MooncakeAgentMetadata) @@ -213,6 +234,7 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase): self.thread = KVCacheRecvingThread( tp_rank=0, tp_size=4, + _prefill_pp_size=1, engine=self.engine, local_engine_id="local_engine", local_handshake_port=5555, @@ -231,7 +253,7 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase): "remote_host": "localhost", "remote_handshake_port": 6666, "offset": 0, - "num_need_pulls": 2, + "tp_num_need_pulls": 2, "all_task_done": False } self.thread.add_request( @@ -242,7 +264,7 @@ class TestKVCacheRecvingThreadBasic(unittest.TestCase): remote_host=test_req["remote_host"], remote_handshake_port=test_req["remote_handshake_port"], offset=test_req["offset"], - num_need_pulls=test_req["num_need_pulls"], + tp_num_need_pulls=test_req["tp_num_need_pulls"], all_task_done=test_req["all_task_done"]) queued = self.thread.request_queue.get_nowait() self.assertEqual(queued["request_id"], "req1") @@ -265,6 +287,7 @@ class TestSocketManagement(unittest.TestCase): self.thread = KVCacheRecvingThread( tp_rank=0, tp_size=4, + _prefill_pp_size=1, engine=self.engine, local_engine_id="local_engine", 
local_handshake_port=5555, @@ -315,10 +338,13 @@ class TestCoreFunctionality(unittest.TestCase): self.ready_event = threading.Event() self.mock_queue = MagicMock() self.vllm_config = MockVllmConfig() - self.kv_caches: Dict[str, Any] = {} + self.kv_caches: Dict[str, Any] = { + "layer_0": (MagicMock(), MagicMock()) + } self.thread = KVCacheRecvingThread( tp_rank=0, tp_size=4, + _prefill_pp_size=1, engine=self.engine, local_engine_id="local_engine", local_handshake_port=5555, @@ -337,7 +363,7 @@ class TestCoreFunctionality(unittest.TestCase): "remote_handshake_port": 6666, "remote_transfer_port": 7777, "offset": 0, - "num_need_pulls": 2, + "tp_num_need_pulls": 2, "all_task_done": False } self.thread.task_tracker = MagicMock() @@ -362,12 +388,14 @@ class TestCoreFunctionality(unittest.TestCase): @patch.object(KVCacheRecvingThread, '_get_remote_metadata') def test_transfer_kv_cache(self, mock_get_meta): - self.thread.kv_caches_base_addr["remote_engine"] = { - 6666: [0x3000, 0x4000] - } - - self.thread._transfer_kv_cache(self.test_req) - + with patch( + 'vllm_ascend.distributed.mooncake_connector.get_ascend_config' + ) as mock_config: + mock_config.return_value.enable_kv_nz = False + self.thread.kv_caches_base_addr["remote_engine"] = { + 6666: [0x3000, 0x4000] + } + self.thread._transfer_kv_cache(self.test_req) self.engine.batch_transfer_sync_read.assert_called_once() call_args, call_kwargs = self.engine.batch_transfer_sync_read.call_args self.assertEqual(call_args[0], "localhost:7777") @@ -398,6 +426,7 @@ class TestMetadataHandling(unittest.TestCase): self.thread = KVCacheRecvingThread( tp_rank=0, tp_size=4, + _prefill_pp_size=1, engine=self.engine, local_engine_id="local_engine", local_handshake_port=5555, @@ -461,6 +490,7 @@ class TestMainThreadLoop(unittest.TestCase): self.thread = KVCacheRecvingThread( tp_rank=0, tp_size=4, + _prefill_pp_size=1, engine=self.engine, local_engine_id="local_engine", local_handshake_port=5555, @@ -482,7 +512,7 @@ class 
TestMainThreadLoop(unittest.TestCase): "remote_handshake_port": 6666, "remote_transfer_port": 7777, "offset": 0, - "num_need_pulls": 2, + "tp_num_need_pulls": 2, "all_task_done": False } @@ -509,6 +539,10 @@ class MockVllmConfig: self.parallel_config.tensor_parallel_size = 2 self.parallel_config.data_parallel_rank = 0 self.parallel_config.data_parallel_size_local = 1 + self.parallel_config.pipeline_parallel_size = 1 + self.parallel_config.data_parallel_rank_local = 0 + self.model_config.get_num_layers_by_block_type = MagicMock( + return_value=32) self.cache_config.block_size = 16 self.kv_transfer_config.kv_port = 5000 self.kv_transfer_config.kv_role = 'kv_producer' @@ -516,11 +550,13 @@ class MockVllmConfig: self.kv_transfer_config.get_from_extra_config.side_effect = lambda k, d: { "prefill": { "tp_size": 2, - "dp_size": 1 + "dp_size": 1, + "pp_size": 1 }, "decode": { "tp_size": 2, - "dp_size": 1 + "dp_size": 1, + "pp_size": 1 } }.get(k, d) self.additional_config = {} @@ -1062,12 +1098,13 @@ class TestMooncakeConnectorWorker(unittest.TestCase): patch('torch.Tensor.element_size', return_value=4), patch('torch.Tensor.data_ptr', return_value=0x1000), patch('math.prod', return_value=128), - patch('random.Random'), patch( 'vllm_ascend.distributed.mooncake_connector.get_tensor_model_parallel_rank', mock_get_tensor_model_parallel_rank), patch('vllm_ascend.distributed.mooncake_connector.get_tp_group', mock_get_tp_group), + patch('vllm_ascend.distributed.mooncake_connector.get_pp_group', + return_value=_mock_pp_group), patch('vllm_ascend.distributed.mooncake_connector.get_ip', mock_get_ip), patch( @@ -1096,8 +1133,6 @@ class TestMooncakeConnectorWorker(unittest.TestCase): patch( 'vllm_ascend.distributed.mooncake_connector.get_decode_context_model_parallel_world_size', return_value=1), - patch('vllm_ascend.distributed.mooncake_connector.get_pcp_group', - return_value=self.mock_pcp_group), patch( 'vllm_ascend.distributed.mooncake_connector.get_ascend_config', 
return_value=MagicMock()), @@ -1146,6 +1181,83 @@ class TestMooncakeConnectorWorker(unittest.TestCase): # Default tp_rank is 0, so device_id should be 10 self.assertIsNotNone(worker.engine) + def test_get_remote_tp_rank(self): + + def get_tp_rank(prefill_tp_size: int, prefill_pp_size: int, + decode_tp_size: int, num_kv_heads: int, + tp_num_need_pulls: int, is_deepseek_mla: bool): + with patch('vllm_ascend.distributed.mooncake_connector.get_ascend_config', + return_value=MagicMock()), \ + patch.object(self.vllm_config.kv_transfer_config, 'get_from_extra_config', + side_effect=lambda k, d=None: { + "prefill": {"tp_size": prefill_tp_size, "dp_size": 1, "pp_size": prefill_pp_size}, + "decode": {"tp_size": decode_tp_size, "dp_size": 1, "pp_size": 1} + }.get(k, d)): + self.vllm_config.model_config.hf_config.num_key_value_heads = num_kv_heads + self.vllm_config.model_config.is_deepseek_mla = is_deepseek_mla + worker = MooncakeConnectorWorker(self.vllm_config, + self.engine_id) + worker.tp_num_need_pulls = tp_num_need_pulls + worker.use_sparse = 0 + return worker._get_remote_ranks_for_req('test') + + self.assertIn( + get_tp_rank(16, 1, 1, 4, 4, False)[0], + [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]) + self.assertIn( + get_tp_rank(8, 1, 1, 4, 4, False)[0], [[0, 2, 4, 6], [1, 3, 5, 7]]) + self.assertIn(get_tp_rank(4, 1, 1, 4, 4, False)[0], [[0, 1, 2, 3]]) + self.assertIn(get_tp_rank(16, 1, 4, 4, 1, False), + [[[0], [4], [8], [12]], [[1], [5], [9], [13]], + [[2], [6], [10], [14]], [[3], [7], [11], [15]]]) + self.assertIn(get_tp_rank(8, 1, 4, 4, 1, False), + [[[0], [2], [4], [6]], [[1], [3], [5], [7]]]) + self.assertIn(get_tp_rank(4, 2, 2, 4, 2, False), + [[[0, 1, 4, 5], [2, 3, 6, 7]]]) + self.assertIn(get_tp_rank(4, 1, 4, 4, 1, False), + [[[0], [1], [2], [3]]]) + self.assertIn( + get_tp_rank(8, 2, 1, 4, 4, False)[0], + [[0, 2, 4, 6, 8, 10, 12, 14], [1, 3, 5, 7, 9, 11, 13, 15]]) + self.assertIn(get_tp_rank(4, 2, 2, 4, 2, False), + [[[0, 1, 4, 5], [2, 3, 6, 
7]]]) + self.assertIn(get_tp_rank(2, 2, 1, 4, 2, False), [[[0, 1, 2, 3]]]) + self.assertIn( + get_tp_rank(4, 4, 2, 8, 2, False), + [[[0, 1, 4, 5, 8, 9, 12, 13], [2, 3, 6, 7, 10, 11, 14, 15]]]) + self.assertIn( + get_tp_rank(4, 2, 1, 4, 4, False)[0], [[0, 1, 2, 3, 4, 5, 6, 7]]) + self.assertIn( + get_tp_rank(4, 4, 1, 4, 4, False)[0], + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]]) + self.assertIn(get_tp_rank(8, 2, 4, 4, 1, False), + [[[0, 8], [2, 10], [4, 12], [6, 14]], + [[1, 9], [3, 11], [5, 13], [7, 15]]]) + self.assertIn(get_tp_rank(4, 2, 4, 4, 4, False), + [[[0, 4], [1, 5], [2, 6], [3, 7]]]) + self.assertIn( + get_tp_rank(4, 4, 4, 4, 1, False), + [[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]]) + self.assertIn( + get_tp_rank(16, 1, 1, 1, 1, + True)[0], [[0], [1], [2], [3], [4], [5], [6], [7], [8], + [9], [10], [11], [12], [13], [14], [15]]) + self.assertIn(get_tp_rank(4, 1, 4, 1, 1, True), [[[0], [1], [2], [3]]]) + self.assertIn( + get_tp_rank(8, 2, 1, 1, 1, True)[0], + [[0, 8], [2, 10], [4, 12], [6, 14], [1, 9], [3, 11], [5, 13], + [7, 15]]) + self.assertIn( + get_tp_rank(4, 4, 1, 1, 1, True)[0], + [[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]) + self.assertIn( + get_tp_rank(8, 2, 4, 1, 1, True)[0], + [[0, 8], [2, 10], [4, 12], [6, 14], [1, 9], [3, 11], [5, 13], + [7, 15]]) + self.assertIn( + get_tp_rank(4, 4, 4, 1, 1, True), + [[[0, 4, 8, 12], [1, 5, 9, 13], [2, 6, 10, 14], [3, 7, 11, 15]]]) + if __name__ == '__main__': unittest.main() diff --git a/tests/ut/worker/test_model_runner_v1.py b/tests/ut/worker/test_model_runner_v1.py index 0f27548a..fe945337 100644 --- a/tests/ut/worker/test_model_runner_v1.py +++ b/tests/ut/worker/test_model_runner_v1.py @@ -22,29 +22,30 @@ from vllm_ascend.worker.model_runner_v1 import NPUModelRunner # yapf: disable @pytest.mark.parametrize( - "soc_version, enable_expert_parallel, world_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method", + "soc_version, 
enable_expert_parallel, world_size, pipeline_size, num_tokens, mc2_tokens_capacity, quant_type, expected_method", [ # Case 1: Expert parallel is disabled, should always be 'allgather' - (AscendDeviceType._910B, False, 8, 100, 256, None, MoECommType.ALLGATHER), - (AscendDeviceType._910_93, False, 16, 500, 256, None, MoECommType.ALLGATHER), + (AscendDeviceType._910B, False, 8, 2, 100, 256, None, MoECommType.ALLGATHER), + (AscendDeviceType._910_93, False, 16, 2, 500, 256, None, MoECommType.ALLGATHER), # Case 2: A2 SOC with w4a8_dynamic -> use alltoall when not mc2 - (AscendDeviceType._910B, True, 8, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL), - (AscendDeviceType._910B, True, 16, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL), - (AscendDeviceType._910B, True, 16, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition + (AscendDeviceType._910B, True, 8, 1, 100, 256, "w4a8_dynamic", MoECommType.ALLTOALL), + (AscendDeviceType._910B, True, 16, 1, 257, 256, "w4a8_dynamic", MoECommType.ALLTOALL), + (AscendDeviceType._910B, True, 16, 1, 100, 256, "w4a8_dynamic", MoECommType.MC2), # meets mc2 condition # Case 3: A2 SOC without w4a8_dynamic -> fallback to allgather - (AscendDeviceType._910B, True, 8, 100, 256, None, MoECommType.ALLGATHER), - (AscendDeviceType._910B, True, 16, 257, 256, None, MoECommType.ALLGATHER), + (AscendDeviceType._910B, True, 8, 2, 100, 256, None, MoECommType.ALLGATHER), + (AscendDeviceType._910B, True, 16, 2, 257, 256, None, MoECommType.ALLGATHER), # Case 4: A3 SOC - (AscendDeviceType._910_93, True, 8, 100, 256, None, MoECommType.MC2), - (AscendDeviceType._910_93, True, 8, 257, 256, None, MoECommType.ALLTOALL), + (AscendDeviceType._910_93, True, 8, 2, 100, 256, None, MoECommType.MC2), + (AscendDeviceType._910_93, True, 8, 2, 257, 256, None, MoECommType.ALLTOALL), ]) # yapf: enable def test_select_moe_comm_method(soc_version, enable_expert_parallel, - world_size, num_tokens, mc2_tokens_capacity, - quant_type, expected_method): + 
world_size, pipeline_size, num_tokens, + mc2_tokens_capacity, quant_type, + expected_method): """ Tests the _select_moe_comm_method with various configurations including quant_type. """ @@ -53,6 +54,7 @@ def test_select_moe_comm_method(soc_version, enable_expert_parallel, mock_runner.parallel_config = MagicMock() mock_runner.parallel_config.enable_expert_parallel = enable_expert_parallel mock_runner.parallel_config.world_size_across_dp = world_size + mock_runner.parallel_config.pipeline_parallel_size = pipeline_size mock_runner.mc2_tokens_capacity = mc2_tokens_capacity # Add vllm_config.model_config.hf_config mock with moe_quantize diff --git a/vllm_ascend/distributed/mooncake_connector.py b/vllm_ascend/distributed/mooncake_connector.py index 58968c1f..2414b92d 100644 --- a/vllm_ascend/distributed/mooncake_connector.py +++ b/vllm_ascend/distributed/mooncake_connector.py @@ -27,8 +27,10 @@ from vllm.distributed.kv_transfer.kv_connector.v1.base import ( KVConnectorBase_V1, KVConnectorMetadata, KVConnectorRole) from vllm.distributed.parallel_state import ( get_decode_context_model_parallel_rank, - get_decode_context_model_parallel_world_size, get_pcp_group, - get_tensor_model_parallel_rank, get_tp_group) + get_decode_context_model_parallel_world_size, get_pp_group, + get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, + get_tp_group) +from vllm.distributed.utils import get_pp_indices from vllm.logger import logger from vllm.utils.network_utils import get_ip, make_zmq_path, make_zmq_socket from vllm.v1.core.sched.output import SchedulerOutput @@ -38,6 +40,14 @@ from vllm.v1.request import RequestStatus from vllm_ascend.ascend_config import get_ascend_config, init_ascend_config from vllm_ascend.distributed.mooncake_transfer_engine import global_te from vllm_ascend.distributed.utils import get_transfer_timeout_value +from vllm_ascend.utils import prefill_context_parallel_enable + +# isort: off +if prefill_context_parallel_enable(): + from 
vllm.distributed import (get_prefill_context_model_parallel_rank, + get_prefill_context_model_parallel_world_size + ) +# isort: on if TYPE_CHECKING: from vllm.attention.backends.abstract import AttentionMetadata @@ -159,14 +169,17 @@ class KVCacheTaskTracker: class KVCacheSendingThread(threading.Thread): - def __init__(self, tp_rank: int, prefill_tp_size: int, - local_engine_id: str, side_channel_host: str, - side_channel_port: int, metadata: MooncakeAgentMetadata, - ready_event: threading.Event, kv_caches: dict[str, Any], - pcp_rank: int): + def __init__(self, vllm_config: VllmConfig, tp_rank: int, + prefill_tp_size: int, local_engine_id: str, + side_channel_host: str, side_channel_port: int, + metadata: MooncakeAgentMetadata, ready_event: threading.Event, + kv_caches: dict[str, Any], pcp_rank: int): super().__init__(daemon=True, name="KVCacheSendingThread") self.tp_rank = tp_rank self.prefill_tp_size = prefill_tp_size + self.pp_rank = get_pp_group().rank_in_group + self.pp_size = vllm_config.parallel_config.pipeline_parallel_size + self.tp_size = get_tensor_model_parallel_world_size() self.local_engine_id = local_engine_id self.side_channel_host = side_channel_host self.side_channel_port = side_channel_port @@ -205,8 +218,8 @@ class KVCacheSendingThread(threading.Thread): # NOTE(rob): we need each rank to have a unique port. This hack to keeps # us moving. We will switch when moving to etcd or where we have a # single ZMQ socket in the scheduler. 
- handshake_port = self.side_channel_port + self.pcp_rank * self.prefill_tp_size \ - + self.tp_rank + device_index = self.pp_rank * self.tp_size + self.tp_rank + self.pcp_rank * self.prefill_tp_size + handshake_port = self.side_channel_port + device_index path = make_zmq_path("tcp", self.side_channel_host, handshake_port) logger.info("Starting listening on path: %s", path) with zmq_ctx(zmq.ROUTER, path) as sock: # type: ignore @@ -258,20 +271,22 @@ class KVCacheSendingThread(threading.Thread): class KVCacheRecvingThread(threading.Thread): - def __init__(self, tp_rank: int, tp_size: int, engine: TransferEngine, - local_engine_id: str, local_handshake_port: int, + def __init__(self, tp_rank: int, tp_size: int, _prefill_pp_size: int, + engine: TransferEngine, local_engine_id: str, + local_handshake_port: int, local_kv_caches_base_addr: list[int], block_len: list[int], ready_event: threading.Event, vllm_config: VllmConfig, kv_caches: dict[str, Any]): super().__init__(daemon=True, name="KVCacheRecvingThread") self.tp_rank = tp_rank self.tp_size = tp_size - + self._prefill_pp_size = _prefill_pp_size self.local_engine_id = local_engine_id self.local_handshake_port = local_handshake_port self.engine = engine self.ready_event = ready_event + self.kv_caches = kv_caches self.kv_caches_base_addr: dict[str, dict[int, list[int]]] = \ SizedDict() self.kv_caches_base_addr[local_engine_id][local_handshake_port] = \ @@ -299,13 +314,22 @@ class KVCacheRecvingThread(threading.Thread): self.vllm_config = vllm_config self.model_config = self.vllm_config.model_config - self.num_key_value_heads = self.model_config.hf_config.num_key_value_heads - self.kv_caches = kv_caches + self.block_size = self.vllm_config.cache_config.block_size + if self.use_mla: + self.k_head_dim = self.model_config.hf_config.kv_lora_rank + self.v_head_dim = self.model_config.hf_config.qk_rope_head_dim + self.num_kv_heads = 1 + else: + self.k_head_dim = self.model_config.hf_config.head_dim + self.v_head_dim = 
self.model_config.hf_config.head_dim + self.num_kv_heads = max( + self.model_config.hf_config.num_key_value_heads // + self.tp_size, 1) def add_request(self, request_id: str, local_block_ids: list[int], remote_block_ids: list[int], remote_engine_id: str, remote_host: str, remote_handshake_port: int, offset: int, - num_need_pulls: int, all_task_done: bool): + tp_num_need_pulls: int, all_task_done: bool): """Add a new request to the queue for processing.""" logger.debug(f"Adding request {request_id} to the queue.") self.request_queue.put({ @@ -316,7 +340,7 @@ class KVCacheRecvingThread(threading.Thread): "remote_host": remote_host, "remote_handshake_port": remote_handshake_port, "offset": offset, - "num_need_pulls": num_need_pulls, + "tp_num_need_pulls": tp_num_need_pulls, "all_task_done": all_task_done }) @@ -376,7 +400,7 @@ class KVCacheRecvingThread(threading.Thread): remote_host = req_meta["remote_host"] remote_handshake_port = req_meta["remote_handshake_port"] offset = req_meta["offset"] - self.num_need_pulls = req_meta["num_need_pulls"] + tp_num_need_pulls = req_meta["tp_num_need_pulls"] # Full prefix cache hit: do not need to read remote blocks, just notify # P worker that we have the blocks we need. 
@@ -394,7 +418,7 @@ class KVCacheRecvingThread(threading.Thread): remote_handshake_port not in self.kv_caches_base_addr[remote_engine_id]: self._get_remote_metadata(remote_host, remote_handshake_port) - if self.num_need_pulls == 1: + if tp_num_need_pulls == 1: grouped_remote_block_ids, grouped_local_block_ids = \ group_concurrent_contiguous(remote_block_ids, local_block_ids) else: @@ -402,11 +426,25 @@ class KVCacheRecvingThread(threading.Thread): local_block_ids = list(map(lambda x: [x], local_block_ids)) grouped_remote_block_ids, grouped_local_block_ids = remote_block_ids, local_block_ids num_transfer_groups = len(grouped_remote_block_ids) + # tp_num_need_pulls: number of KV caches each Decode node needs to pull from each PP stage + # Due to GQA, different KV heads are distributed across different ranks, so there are offsets + # indicating which KV head to pull + global_offset = offset # Global offset of request across all ranks + prefill_pp_rank = offset // tp_num_need_pulls # PP rank where current request resides + inner_offset = offset % tp_num_need_pulls # Offset within each PP stage remote_kv_caches_base_addrs = \ self.kv_caches_base_addr[remote_engine_id][remote_handshake_port] + num_layers = self.model_config.hf_config.num_hidden_layers + first_layer_index, end_layer_index = get_pp_indices( + num_layers, prefill_pp_rank, self._prefill_pp_size) + num_cache_per_layer = len(list( + self.kv_caches.values())[0]) # Number of KV caches per layer local_kv_caches_base_addrs = \ - self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port] + self.kv_caches_base_addr[self.local_engine_id][self.local_handshake_port][first_layer_index*num_cache_per_layer : end_layer_index*num_cache_per_layer] + logger.debug( + f"transfer kv cache first_layer_index:{first_layer_index} , end_layer_index:{end_layer_index}" + ) remote_transfer_port = self.remote_te_port[remote_engine_id][ remote_handshake_port] num_blocks = len(local_block_ids) @@ -422,11 +460,11 @@ class 
KVCacheRecvingThread(threading.Thread): block_len = (self.block_len[k % 3]) else: block_len = (self.block_len[0]) - inner_block_len = block_len // self.num_need_pulls + inner_block_len = block_len // tp_num_need_pulls for remote_block_id, local_block_id in zip( grouped_remote_block_ids, grouped_local_block_ids): src = src_layer_base_addr + local_block_id[ - 0] * block_len + offset * inner_block_len + 0] * block_len + inner_offset * inner_block_len dst = dst_layer_base_addr + remote_block_id[0] * inner_block_len length = inner_block_len * len(local_block_id) src_list.append(src) @@ -447,10 +485,17 @@ class KVCacheRecvingThread(threading.Thread): " %d blocks). local_ip %s local_device_id %s remote_session_id %s", request_id, req_transfer_elapsed, num_transfer_groups, num_blocks, get_ip(), self.tp_rank, session_id) - if self.num_need_pulls > 1 and offset == self.num_need_pulls - 1: - self._cat_kv_cache(grouped_local_block_ids) - def _cat_kv_cache(self, block_ids: list[list[int]]): + # Determine if the current position is the offset position at the end of the KV transmission. 
+ is_kv_transfer_end = ( + global_offset == tp_num_need_pulls * self._prefill_pp_size - 1) + need_cat_cache = tp_num_need_pulls > 1 and is_kv_transfer_end + # need_nz_cache may cause errors in non-MLA models + if need_cat_cache: + self._cat_kv_cache(grouped_local_block_ids, tp_num_need_pulls) + + def _cat_kv_cache(self, block_ids: list[list[int]], + tp_num_need_pulls: int): # Get necessary parameters k_cache = list(self.kv_caches.values())[0][0] dtype = k_cache.dtype @@ -506,9 +551,11 @@ class KVCacheRecvingThread(threading.Thread): # Transpose KV cache k_buffer = self._transpose_kv_cache_between_head( - k_buffer, num_blocks, block_size, block_len, num_kv_head) + k_buffer, num_blocks, block_size, block_len, num_kv_head, + tp_num_need_pulls) v_buffer = self._transpose_kv_cache_between_head( - v_buffer, num_blocks, block_size, block_len, num_kv_head) + v_buffer, num_blocks, block_size, block_len, num_kv_head, + tp_num_need_pulls) # Reshape and cache the processed buffers torch_npu._npu_reshape_and_cache( @@ -522,11 +569,11 @@ class KVCacheRecvingThread(threading.Thread): # Clean up buffers del k_buffer, v_buffer - def _transpose_kv_cache_between_head(self, buffer: torch.Tensor, - num_blocks: int, block_size: int, - block_len: int, - num_kv_head: int) -> torch.Tensor: - buffer = buffer.view(num_blocks, self.num_need_pulls, block_size, -1) + def _transpose_kv_cache_between_head( + self, buffer: torch.Tensor, num_blocks: int, block_size: int, + block_len: int, num_kv_head: int, + tp_num_need_pulls: int) -> torch.Tensor: + buffer = buffer.view(num_blocks, tp_num_need_pulls, block_size, -1) buffer.transpose_(1, 2) return buffer.contiguous().view(block_len, num_kv_head, -1) @@ -631,8 +678,8 @@ class MooncakeConnectorMetadata(KVConnectorMetadata): remote_engine_id=kv_transfer_params["remote_engine_id"], remote_host=kv_transfer_params["remote_host"], remote_port=kv_transfer_params["remote_port"], - remote_pcp_size=kv_transfer_params["remote_pcp_size"], - 
remote_dcp_size=kv_transfer_params["remote_dcp_size"], + remote_pcp_size=kv_transfer_params.get("remote_pcp_size", 1), + remote_dcp_size=kv_transfer_params.get("remote_dcp_size", 1), ) @@ -736,13 +783,17 @@ class MooncakeConnectorScheduler: self.side_channel_host = get_ip() self.pcp_size = vllm_config.parallel_config.prefill_context_parallel_size self.dcp_size = vllm_config.parallel_config.decode_context_parallel_size + self.max_device_id = vllm_config.parallel_config.tensor_parallel_size * \ + vllm_config.parallel_config.data_parallel_size * \ + self.pcp_size * \ + vllm_config.parallel_config.pipeline_parallel_size # Handshake base port self.side_channel_port = ( vllm_config.kv_transfer_config.kv_port + vllm_config.parallel_config.data_parallel_rank * - vllm_config.parallel_config.tensor_parallel_size * self.pcp_size) - + vllm_config.parallel_config.tensor_parallel_size * + vllm_config.parallel_config.pipeline_parallel_size * self.pcp_size) # Requests that need to start recv. # New requests are added by update_state_after_alloc in # the scheduler. Used to make metadata passed to Worker. 
@@ -894,15 +945,24 @@ class MooncakeConnectorWorker: self.tp_rank = get_tensor_model_parallel_rank() self.tp_size = vllm_config.parallel_config.tensor_parallel_size self.tp_group = get_tp_group() + self.pp_rank = get_pp_group().rank_in_group + self.dp_rank = vllm_config.parallel_config.data_parallel_rank_local + self.dp_size = vllm_config.parallel_config.data_parallel_size_local + self.pp_size = vllm_config.parallel_config.pipeline_parallel_size self.kv_caches: dict[str, torch.Tensor] = {} self.side_channel_host = get_ip() - self.pcp_size = get_pcp_group().world_size - self.pcp_rank = get_pcp_group( - ).rank_in_group if self.pcp_size > 1 else 0 + self.pcp_size = get_prefill_context_model_parallel_world_size( + ) if prefill_context_parallel_enable() else 1 + # Assert that pp_size and pcp_size cannot both be greater than 1 + assert not (self.pp_size > 1 and self.pcp_size + > 1), "pp and pcp cannot open in same time" + self.pcp_rank = get_prefill_context_model_parallel_rank( + ) if self.pcp_size > 1 else 0 self.dcp_size = get_decode_context_model_parallel_world_size() self.dcp_rank = get_decode_context_model_parallel_rank( ) if self.dcp_size > 1 else 0 + self.max_device_id = self.tp_size * self.dp_size * self.pcp_size * self.pp_size self.kv_role = vllm_config.kv_transfer_config.kv_role self.num_key_value_heads = self.vllm_config.model_config.hf_config.num_key_value_heads @@ -910,10 +970,12 @@ class MooncakeConnectorWorker: self.side_channel_port = ( vllm_config.kv_transfer_config.kv_port + vllm_config.parallel_config.data_parallel_rank * - vllm_config.parallel_config.tensor_parallel_size * self.pcp_size) - self.handshake_port = self.side_channel_port + self.pcp_rank * self.tp_size + self.tp_rank + vllm_config.parallel_config.tensor_parallel_size * + vllm_config.parallel_config.pipeline_parallel_size * self.pcp_size) + device_index = (self.pp_rank + + self.pcp_rank) * self.tp_size + self.tp_rank + self.handshake_port = self.side_channel_port + device_index 
self.sockets: dict = {} - logger.info("Initializing Mooncake work %s", engine_id) self.engine = global_te.get_transfer_engine(self.side_channel_host, device_name=None) self.te_rpc_port = self.engine.get_rpc_port() @@ -926,13 +988,13 @@ class MooncakeConnectorWorker: self.vllm_config = vllm_config self.block_size = vllm_config.cache_config.block_size if self.vllm_config.model_config.is_deepseek_mla: - self.num_need_pulls = 1 + self.tp_num_need_pulls = 1 else: num_d_block_heads = max(1, self.num_key_value_heads // self.tp_size) num_p_block_heads = max( 1, self.num_key_value_heads // self._prefill_tp_size) - self.num_need_pulls = num_d_block_heads // num_p_block_heads + self.tp_num_need_pulls = num_d_block_heads // num_p_block_heads def _get_prefill_decode_size(self, vllm_config: VllmConfig): # get prefill tp and dp size from extra config @@ -945,7 +1007,8 @@ class MooncakeConnectorWorker: assert "dp_size" in prefill_parallel_config.keys() self._prefill_dp_size = prefill_parallel_config["dp_size"] - + # get prefill pp size from extra config + self._prefill_pp_size = prefill_parallel_config.get("pp_size", 1) # get decode tp and dp size from extra config decode_parallel_config: dict[ str, Any] = vllm_config.kv_transfer_config.get_from_extra_config( @@ -954,6 +1017,9 @@ class MooncakeConnectorWorker: self._decode_tp_size = decode_parallel_config["tp_size"] assert "dp_size" in decode_parallel_config.keys() self._decode_dp_size = decode_parallel_config["dp_size"] + # get decode pp size from extra config + self._decode_pp_size = decode_parallel_config.get("pp_size", 1) + assert self._decode_pp_size == 1, "decode pp size must be 1" def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): """Register the KV Cache data.""" @@ -1052,15 +1118,15 @@ class MooncakeConnectorWorker: ready_event = threading.Event() if self.kv_role == 'kv_producer': self.kv_send_thread = KVCacheSendingThread( - self.tp_rank, self._prefill_tp_size, self.engine_id, - self.side_channel_host, 
self.side_channel_port, metadata, - ready_event, self.kv_caches, self.pcp_rank) + self.vllm_config, self.tp_rank, self._prefill_tp_size, + self.engine_id, self.side_channel_host, self.side_channel_port, + metadata, ready_event, self.kv_caches, self.pcp_rank) self.kv_send_thread.start() else: self.kv_recv_thread = KVCacheRecvingThread( - self.tp_rank, self.tp_size, self.engine, self.engine_id, - self.handshake_port, kv_caches_base_addr, self.block_len, - ready_event, self.vllm_config, self.kv_caches) + self.tp_rank, self.tp_size, self._prefill_pp_size, self.engine, + self.engine_id, self.handshake_port, kv_caches_base_addr, + self.block_len, ready_event, self.vllm_config, self.kv_caches) self.kv_recv_thread.start() ready_event.wait() @@ -1089,7 +1155,7 @@ class MooncakeConnectorWorker: Use this function to calculate remote port and remote block number of each remote P node that we need to pull. """ if meta.remote_pcp_size * meta.remote_dcp_size * self.pcp_size * self.dcp_size == 1: - choosen_rank_list = self._get_remote_tp_rank(req_id) + choosen_rank_list = self._get_remote_rank(req_id) remote_handshake_port_list = [[ x + meta.remote_port for x in choosen_rank_list ]] @@ -1174,77 +1240,121 @@ class MooncakeConnectorWorker: meta.remote_engine_id, len(meta.local_block_ids), len(meta.remote_block_ids)) - remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata( - req_id, meta) + if prefill_context_parallel_enable(): + remote_handshake_port_list, local_block_ids_list, remote_block_ids_list = self._get_kv_split_metadata( + req_id, meta) - for pcp_dcp_rank in range(len(remote_handshake_port_list)): - if len(local_block_ids_list[pcp_dcp_rank]) + len( - remote_block_ids_list[pcp_dcp_rank]) == 0: - continue - for i in range(self.num_need_pulls): + for pcp_dcp_rank in range(len(remote_handshake_port_list)): + if len(local_block_ids_list[pcp_dcp_rank]) + len( + remote_block_ids_list[pcp_dcp_rank]) == 0: + continue + for i in 
range(self.tp_num_need_pulls): + assert self.kv_recv_thread is not None + self.kv_recv_thread.add_request( + request_id=req_id, + local_block_ids=local_block_ids_list[pcp_dcp_rank], + remote_block_ids=remote_block_ids_list[ + pcp_dcp_rank], + remote_engine_id=meta.remote_engine_id, + remote_host=meta.remote_host, + remote_handshake_port=remote_handshake_port_list[ + pcp_dcp_rank][i], + offset=i, + tp_num_need_pulls=self.tp_num_need_pulls, + all_task_done=( + pcp_dcp_rank + == len(remote_handshake_port_list) - 1 + and i == self.tp_num_need_pulls - 1)) + else: #TODO: support prefill context parallel and pipeline parallel open at the same time + choosen_rank_list = self._get_remote_rank(req_id) + remote_handshake_port_list = [[x + meta.remote_port] + for x in choosen_rank_list] + for i in range(self.tp_num_need_pulls * self._prefill_pp_size): assert self.kv_recv_thread is not None self.kv_recv_thread.add_request( request_id=req_id, - local_block_ids=local_block_ids_list[pcp_dcp_rank], - remote_block_ids=remote_block_ids_list[pcp_dcp_rank], + local_block_ids=meta.local_block_ids, + remote_block_ids=meta.remote_block_ids, remote_engine_id=meta.remote_engine_id, remote_host=meta.remote_host, - remote_handshake_port=remote_handshake_port_list[ - pcp_dcp_rank][i], + remote_handshake_port=remote_handshake_port_list[i][0], offset=i, - num_need_pulls=self.num_need_pulls, - all_task_done=(pcp_dcp_rank - == len(remote_handshake_port_list) - 1 - and i == self.num_need_pulls - 1)) + tp_num_need_pulls=self.tp_num_need_pulls, + all_task_done=(i == self.tp_num_need_pulls * + self._prefill_pp_size - 1)) if self.kv_send_thread is not None: for req_id, delay_start_time in metadata.requests_to_send.items(): - if self.tp_rank in self._prefill_get_remote_tp_rank(req_id): + if self.tp_rank in self._prefill_get_remote_rank(req_id): self.kv_send_thread.add_delayed_request( req_id, delay_start_time) else: self.kv_send_thread.add_not_transfer_request(req_id) - def 
_prefill_get_remote_tp_rank(self, req_id: str) -> List[int]: - return sum(self._get_remote_tp_ranks_for_req(req_id), []) + def _prefill_get_remote_rank(self, req_id: str) -> List[int]: + return sum(self._get_remote_ranks_for_req(req_id), []) - def _get_remote_tp_rank(self, req_id: str) -> List[int]: - return self._get_remote_tp_ranks_for_req(req_id)[self.tp_rank] + def _get_remote_rank(self, req_id: str) -> List[int]: + return self._get_remote_ranks_for_req(req_id)[self.tp_rank] - def _get_remote_tp_ranks_for_req(self, req_id: str) -> List[List[int]]: - if self._prefill_tp_size == self._decode_tp_size: - result = list(map(lambda x: [x], range(self._prefill_tp_size))) - return result - - seed = string_to_int64_hash(req_id) - rand = random.Random(seed) - sampled_nums = [] - ori_data = np.arange(self._prefill_tp_size) + def _get_remote_tp_ranks(self, tp_ori_data: np.ndarray, + rand_group_index: list[int], + num_groups: int) -> List[List[int]]: # random split prefill tp list + tp_sampled_nums = [] if self._prefill_tp_size > self.num_key_value_heads or self.vllm_config.model_config.is_deepseek_mla or self.use_sparse: - # use deepseek mla, num_key_value_heads == 128, but consider as 1 - if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse: - num_kv_head = 1 - else: - num_kv_head = self.num_key_value_heads - num_groups = len(ori_data) // num_kv_head - ori_data = ori_data.reshape(-1, num_groups) - rand_group_index = rand.sample(range(num_groups), \ - max(self._decode_tp_size // num_kv_head, 1)) # random choose a group - - choosen_group = ori_data[:, [rand_group_index]] + tp_ori_data = tp_ori_data.reshape(-1, num_groups) + choosen_group = tp_ori_data[:, [rand_group_index]] flattened = choosen_group.reshape(-1).tolist() - sampled_nums = [ - flattened[i:i + self.num_need_pulls] - for i in range(0, len(flattened), self.num_need_pulls) + tp_sampled_nums = [ + flattened[i:i + self.tp_num_need_pulls] + for i in range(0, len(flattened), self.tp_num_need_pulls) ] - # 
non-random split else: group_size = self._prefill_tp_size // self._decode_tp_size for i in range(self._decode_tp_size): - ori_data_slice = ori_data[i * group_size:(i + 1) * group_size] - sampled_nums.append(ori_data_slice.tolist()) + slice = tp_ori_data[i * group_size:(i + 1) * group_size] + tp_sampled_nums.append(slice.tolist()) + return tp_sampled_nums + + def _get_remote_ranks_for_req(self, req_id: str) -> List[List[int]]: + # Divide the ports according to the TP within the PP + sampled_nums = [] + if self._prefill_tp_size == self._decode_tp_size: + sampled_nums = list( + map( + lambda tp: [ + tp + pp * self._prefill_tp_size + for pp in range(self._prefill_pp_size) + ], range(self._prefill_tp_size))) + return sampled_nums + # use deepseek mla, num_key_value_heads == 128, but consider as 1 + if self.vllm_config.model_config.is_deepseek_mla or self.use_sparse: + num_kv_head = 1 + else: + num_kv_head = self.num_key_value_heads + ori_data = np.arange(self._prefill_tp_size * self._prefill_pp_size) + seed = string_to_int64_hash(req_id) + rand = random.Random(seed) + # random split prefill tp list + ori_data = ori_data.reshape(self._prefill_pp_size, -1) + num_groups = max( + 1, + len(ori_data[0]) // num_kv_head + ) # The number of redundant copies for each KV head within the PP stage + rand_group_index = rand.sample(range(num_groups), \ + (max(self._decode_tp_size // num_kv_head, 1))) # random choose a group + all_results = [ + self._get_remote_tp_ranks(ori_data[pp_index], rand_group_index, + num_groups) + for pp_index in range(self._prefill_pp_size) + ] + for group_index in range(len(all_results[0])): + group = [] + for pp_index in range(self._prefill_pp_size): + group.extend(all_results[pp_index][group_index]) + sampled_nums.append(group) return sampled_nums diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index d903bcb1..ca9095a8 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py 
@@ -46,7 +46,8 @@ from vllm.compilation.counter import compilation_counter from vllm.compilation.monitor import set_cudagraph_capturing_enabled from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config) -from vllm.distributed import tensor_model_parallel_all_gather +from vllm.distributed import (get_tensor_model_parallel_world_size, + tensor_model_parallel_all_gather) from vllm.distributed.ec_transfer import get_ec_transfer, has_ec_transfer from vllm.distributed.kv_transfer import (get_kv_transfer_group, has_kv_transfer_group) @@ -1765,11 +1766,22 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): else: assert intermediate_tensors is not None assert self.intermediate_tensors is not None + # If both flashcomm1 and pp are used simultaneously, + # the shape of the received data and the shape of the space to be copied to will not match, + # requiring a recalculation of the incoming data's shape. + tp_size = get_tensor_model_parallel_world_size() + num_input_tokens_with_flashcomm1 = num_input_tokens + if enable_sp(): + num_input_tokens_with_flashcomm1 = (num_input_tokens + + tp_size - 1) // tp_size for k, v in intermediate_tensors.items(): - self.intermediate_tensors[k][:num_input_tokens].copy_( - v[:num_input_tokens], non_blocking=True) + self.intermediate_tensors[ + k][:num_input_tokens_with_flashcomm1].copy_( + v[:num_input_tokens_with_flashcomm1], + non_blocking=True) intermediate_tensors = IntermediateTensors({ - k: v[:num_input_tokens] + k: + v[:num_input_tokens_with_flashcomm1] for k, v in self.intermediate_tensors.items() }) @@ -2044,7 +2056,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): update_attn_params(self.update_stream, forward_context, maybe_padded_num_tokens) - if get_forward_context().sp_enabled: + if get_forward_context().sp_enabled and not isinstance( + hidden_states, IntermediateTensors): hidden_states = tensor_model_parallel_all_gather(hidden_states, 0) 
pad_size = get_forward_context().pad_size if pad_size > 0: @@ -2366,7 +2379,8 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): moe_comm_type = MoECommType.ALLGATHER elif soc_version in {AscendDeviceType._910B}: if (num_tokens <= self.mc2_tokens_capacity - and self.parallel_config.world_size_across_dp >= 16): + and self.parallel_config.world_size_across_dp / + self.parallel_config.pipeline_parallel_size >= 16): moe_comm_type = MoECommType.MC2 else: # Currently, w4a8_dynamic does not support allgatherep @@ -3131,10 +3145,16 @@ class NPUModelRunner(LoRAModelRunnerMixin, ECConnectorModelRunnerMixin): if get_pp_group().is_first_rank: intermediate_tensors = None else: + # When PP and flashcomm1 are enabled, during dummy_run the estimated space should divide num_tokens by tp_size; + # otherwise, on non-first PP ranks it would effectively perform an extra all-gather, leading to incorrect memory estimation and potentially causing OOM. + actual_tokens = num_tokens + if enable_sp(): + tp_size = get_tensor_model_parallel_world_size() + actual_tokens = num_tokens // tp_size if self.intermediate_tensors is None: self.intermediate_tensors = ( self.model.make_empty_intermediate_tensors( - batch_size=num_tokens, + batch_size=actual_tokens, dtype=self.dtype, device=self.device)) intermediate_tensors = IntermediateTensors({ diff --git a/vllm_ascend/worker/worker_v1.py b/vllm_ascend/worker/worker_v1.py index 2cb574d6..265e5211 100644 --- a/vllm_ascend/worker/worker_v1.py +++ b/vllm_ascend/worker/worker_v1.py @@ -52,9 +52,9 @@ from vllm_ascend.device_allocator.camem import CaMemAllocator from vllm_ascend.distributed.parallel_state import init_ascend_model_parallel from vllm_ascend.ops.triton.triton_utils import init_device_properties_triton from vllm_ascend.platform import NPUPlatform -from vllm_ascend.utils import (check_ascend_device_type, is_enable_nz, - register_ascend_customop, sleep_mode_enabled, - try_register_lib) +from vllm_ascend.utils import 
(check_ascend_device_type, enable_sp, + is_enable_nz, register_ascend_customop, + sleep_mode_enabled, try_register_lib) from vllm_ascend.worker.model_runner_v1 import NPUModelRunner torch._dynamo.trace_rules.clear_lru_cache() # noqa: E402 @@ -296,9 +296,14 @@ class NPUWorker(WorkerBase): intermediate_tensors = None forward_pass = scheduler_output.total_num_scheduled_tokens > 0 if forward_pass and not get_pp_group().is_first_rank: + # If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise it will conflict with the all-gather operation in flashcomm1. + if enable_sp(): + all_gather_group = None + else: + all_gather_group = get_tp_group() intermediate_tensors = IntermediateTensors( get_pp_group().recv_tensor_dict( - all_gather_group=get_tp_group())) + all_gather_group=all_gather_group)) output = self.model_runner.execute_model(scheduler_output, intermediate_tensors) @@ -309,9 +314,13 @@ class NPUWorker(WorkerBase): parallel_config = self.vllm_config.parallel_config assert parallel_config.distributed_executor_backend != ( "external_launcher") and not get_pp_group().is_last_rank - + # If flashcomm1 is used, this all_gather_group parameter needs to be removed, otherwise it will conflict with the all-gather operation in flashcomm1. + if enable_sp(): + all_gather_group = None + else: + all_gather_group = get_tp_group() get_pp_group().send_tensor_dict(output.tensors, - all_gather_group=get_tp_group()) + all_gather_group=all_gather_group) kv_connector_output = output.kv_connector_output if not kv_connector_output: