From 71df17f4e67568e24e248850afd14948eb0d43b6 Mon Sep 17 00:00:00 2001 From: Qiu Date: Mon, 23 Mar 2026 15:44:21 +0800 Subject: [PATCH] bugfix(MC2): refactor the comm group of MC2 to be compatible with PP (#7291) ### What this PR does / why we need it? This PR refactors the communication group of MC2 to keep it consistent with vllm's EP group, making it compatible with PP. - vLLM version: v0.17.0 - vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d --------- Signed-off-by: QiuChunshuo --- .github/workflows/scripts/config.yaml | 4 +- tests/e2e/conftest.py | 474 +++++++++++++++++- .../2-cards/test_pipeline_parallel.py | 49 -- .../4-cards/test_pipeline_parallel.py | 88 ++++ vllm_ascend/distributed/parallel_state.py | 45 +- 5 files changed, 571 insertions(+), 89 deletions(-) delete mode 100644 tests/e2e/multicard/2-cards/test_pipeline_parallel.py create mode 100644 tests/e2e/multicard/4-cards/test_pipeline_parallel.py diff --git a/.github/workflows/scripts/config.yaml b/.github/workflows/scripts/config.yaml index c99d36a5..eef2e817 100644 --- a/.github/workflows/scripts/config.yaml +++ b/.github/workflows/scripts/config.yaml @@ -127,8 +127,6 @@ e2e-multicard-2-cards: estimated_time: 180 - name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_qwen3_w4a4_distributed_tp2 estimated_time: 202 -- name: tests/e2e/multicard/2-cards/test_pipeline_parallel.py - estimated_time: 357 - name: tests/e2e/multicard/2-cards/test_prefix_caching.py estimated_time: 470 - name: tests/e2e/multicard/2-cards/test_quantization.py @@ -165,3 +163,5 @@ e2e-multicard-4-cards: is_skipped: true - name: tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py estimated_time: 1340 +- name: tests/e2e/multicard/4-cards/test_pipeline_parallel.py + estimated_time: 357 diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index b1c47086..29c09601 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -453,7 +453,6 @@ class RemoteEPDServer(RemoteOpenAIServer): self.env_dict.update(env_dict) self.env_dict["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1" - self.env_dict["VLLM_USE_V1"] = "1" self.env_dict["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True" self.env_dict["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" @@ -626,6 +625,126 @@ class DisaggEpdProxy(RemoteEPDServer): super()._terminate_server() +_DP_RUNNER_START_TIMEOUT_SECONDS = 900.0 +_DP_RUNNER_REQUEST_TIMEOUT_SECONDS = 900.0 +_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS = 30.0 + + +def _split_data_parallel_indices(num_items: int, dp_size: int) -> list[list[int]]: + if num_items < 0: + raise ValueError("num_items must be non-negative") + if dp_size <= 0: + raise ValueError("dp_size must be positive") + + floor = num_items // dp_size + remainder = num_items % dp_size + + def start(rank: int) -> int: + return rank * floor + min(rank, remainder) + + return [list(range(start(rank), start(rank + 1))) for rank in range(dp_size)] + + +def _slice_optional_inputs(inputs: PromptImageInput | PromptAudioInput | PromptVideoInput | None, indices: list[int]): + if inputs is None: + return None + return [inputs[index] for index in indices] + + +def _slice_list_inputs(items: list[Any], indices: list[int]) -> list[Any]: + return [items[index] for index in indices] + + +def _merge_data_parallel_results(total_items: int, shard_results: list[tuple[list[int], list[Any]]]) -> list[Any]: + merged: list[Any] = [None] * total_items + for indices, results in shard_results: + if not indices: + continue + if len(indices) != 
len(results): + raise RuntimeError("Mismatched result count returned by data parallel worker") + for index, result in zip(indices, results): + merged[index] = result + + if any(result is None for result in merged): + raise RuntimeError("Some data parallel results were not returned") + + return merged + + +def _normalize_score_inputs(text_1: str | list[str], text_2: str | list[str]) -> tuple[list[str], list[str]]: + if isinstance(text_1, str) and isinstance(text_2, str): + return [text_1], [text_2] + if isinstance(text_1, str): + return [text_1] * len(text_2), list(text_2) + if isinstance(text_2, str): + return list(text_1), [text_2] * len(text_1) + if len(text_1) != len(text_2): + raise ValueError("`text_1` and `text_2` must have the same length") + return list(text_1), list(text_2) + + +def _run_vllm_runner_dp_worker(conn, llm_kwargs: dict[str, Any], dp_rank: int, dp_size: int, master_port: int) -> None: + llm = None + try: + os.environ["VLLM_DP_RANK"] = str(dp_rank) + os.environ["VLLM_DP_RANK_LOCAL"] = str(dp_rank) + os.environ["VLLM_DP_SIZE"] = str(dp_size) + os.environ["VLLM_DP_MASTER_IP"] = "127.0.0.1" + os.environ["VLLM_DP_MASTER_PORT"] = str(master_port) + os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn" + + llm = LLM(**llm_kwargs) + conn.send({"status": "ready", "rank": dp_rank}) + + while True: + request = conn.recv() + command = request["command"] + if command == "shutdown": + break + + result: Any + if command == "generate": + req_outputs = llm.generate( + request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"] + ) + result = VllmRunner._finalize_generate_outputs(req_outputs) + elif command == "generate_w_logprobs": + req_outputs = llm.generate( + request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"] + ) + result = VllmRunner._final_steps_generate_w_logprobs(req_outputs) + elif command == "classify": + req_outputs = llm.classify(request["prompts"]) + result = [req_output.outputs.probs for req_output in req_outputs] + elif command == "embed": + req_outputs = llm.embed(request["inputs"], *request["args"], **request["kwargs"]) + result = [req_output.outputs.embedding for req_output in req_outputs] + elif command == "encode": + req_outputs = llm.encode(request["prompts"]) + result = [req_output.outputs.data for req_output in req_outputs] + elif command == "reward": + req_outputs = llm.reward(request["prompts"]) + result = [req_output.outputs.data for req_output in req_outputs] + elif command == "score": + req_outputs = llm.score(request["text_1"], request["text_2"], *request["args"], **request["kwargs"]) + result = [req_output.outputs.score for req_output in req_outputs] + else: + raise ValueError(f"Unsupported data parallel command: {command}") + + conn.send({"status": "ok", "rank": dp_rank, "indices": request["indices"], "result": result}) + except Exception: + with contextlib.suppress(Exception): + conn.send({"status": "error", "rank": dp_rank, "traceback": traceback.format_exc()}) + raise + finally: + if llm is not None: + del llm + clear_ascend_config() + cleanup_dist_env_and_memory() + with contextlib.suppress(Exception): + conn.close() + + class VllmRunner: def __init__( self, @@ -645,6 +764,10 @@ class VllmRunner: quantization: str | None = None, **kwargs, ) -> None: + data_parallel_size = int(kwargs.get("data_parallel_size", 1)) + if data_parallel_size > 1: + raise ValueError("VllmRunner does not support `data_parallel_size > 1`; use `DPVllmRunner` instead.") + self.model = LLM( model=model_name, runner=runner, @@ 
-664,6 +787,22 @@ class VllmRunner: **kwargs, ) + @staticmethod + def _finalize_generate_outputs(req_outputs: list[RequestOutput]) -> list[tuple[list[list[int]], list[str]]]: + outputs: list[tuple[list[list[int]], list[str]]] = [] + for req_output in req_outputs: + prompt_str = req_output.prompt + prompt_ids = req_output.prompt_token_ids + req_sample_output_ids: list[list[int]] = [] + req_sample_output_strs: list[str] = [] + for sample in req_output.outputs: + output_str = sample.text + output_ids = list(sample.token_ids) + req_sample_output_ids.append(prompt_ids + output_ids) + req_sample_output_strs.append((prompt_str or "") + output_str) + outputs.append((req_sample_output_ids, req_sample_output_strs)) + return outputs + def get_inputs( self, prompts: list[str] | list[torch.Tensor] | list[int], @@ -698,7 +837,7 @@ class VllmRunner: def generate( self, - prompts: list[str] | list[torch.Tensor], + prompts: list[str] | list[torch.Tensor] | list[list[int]], sampling_params: SamplingParams, images: PromptImageInput | None = None, videos: PromptVideoInput | None = None, @@ -706,22 +845,8 @@ class VllmRunner: **kwargs: Any, ) -> list[tuple[list[list[int]], list[str]]]: inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios) - req_outputs = self.model.generate(inputs, sampling_params=sampling_params, **kwargs) - - outputs: list[tuple[list[list[int]], list[str]]] = [] - for req_output in req_outputs: - prompt_str = req_output.prompt - prompt_ids = req_output.prompt_token_ids - req_sample_output_ids: list[list[int]] = [] - req_sample_output_strs: list[str] = [] - for sample in req_output.outputs: - output_str = sample.text - output_ids = list(sample.token_ids) - req_sample_output_ids.append(prompt_ids + output_ids) - req_sample_output_strs.append((prompt_str or "") + output_str) - outputs.append((req_sample_output_ids, req_sample_output_strs)) - return outputs + return self._finalize_generate_outputs(req_outputs) @staticmethod def _final_steps_generate_w_logprobs( @@ -760,7 +885,7 @@ class VllmRunner: def generate_greedy( self, - prompts: list[str] | list[torch.Tensor], + prompts: list[str] | list[torch.Tensor] | list[list[int]], max_tokens: int, images: PromptImageInput | None = None, videos: PromptVideoInput | None = None, @@ -842,6 +967,319 @@ class VllmRunner: cleanup_dist_env_and_memory() +class DPVllmRunner(VllmRunner): + def __init__( + self, + model_name: str, + runner: RunnerOption = "auto", + convert: ConvertOption = "auto", + tokenizer_name: str | None = None, + tokenizer_mode: str = "auto", + max_model_len: int | None = 1024, + dtype: str = "auto", + disable_log_stats: bool = True, + tensor_parallel_size: int = 1, + block_size: int = 16, + enable_chunked_prefill: bool = True, + swap_space: int = 4, + enforce_eager: bool | None = False, + quantization: str | None = None, + data_parallel_size: int = 2, + **kwargs, + ) -> None: + if data_parallel_size < 2: + raise ValueError("DPVllmRunner requires `data_parallel_size >= 2`") + + self._dp_size = data_parallel_size + self._dp_parent_conns: list[Any] = [] + self._dp_processes: list[Any] = [] + self._dp_start_timeout = float(kwargs.pop("dp_start_timeout", _DP_RUNNER_START_TIMEOUT_SECONDS)) + self._dp_request_timeout = float(kwargs.pop("dp_request_timeout", _DP_RUNNER_REQUEST_TIMEOUT_SECONDS)) + + llm_kwargs = dict( + model=model_name, + runner=runner, + convert=convert, + tokenizer=tokenizer_name, + tokenizer_mode=tokenizer_mode, + trust_remote_code=True, + dtype=dtype, + swap_space=swap_space, + 
enforce_eager=enforce_eager, + disable_log_stats=disable_log_stats, + tensor_parallel_size=tensor_parallel_size, + max_model_len=max_model_len, + block_size=block_size, + enable_chunked_prefill=enable_chunked_prefill, + quantization=quantization, + **kwargs, + ) + + cleanup_dist_env_and_memory() + self._start_data_parallel_workers(llm_kwargs) + + @property + def model(self) -> LLM: + raise RuntimeError("Direct access to `runner.model` is not supported by `DPVllmRunner`.") + + def _start_data_parallel_workers(self, llm_kwargs: dict[str, Any]) -> None: + ctx = multiprocessing.get_context("spawn") + master_port = get_open_port() + + try: + for dp_rank in range(self._dp_size): + parent_conn, child_conn = ctx.Pipe() + proc = ctx.Process( + target=_run_vllm_runner_dp_worker, + args=(child_conn, llm_kwargs, dp_rank, self._dp_size, master_port), + ) + proc.start() + child_conn.close() + self._dp_parent_conns.append(parent_conn) + self._dp_processes.append(proc) + + for rank, conn in enumerate(self._dp_parent_conns): + if not conn.poll(self._dp_start_timeout): + raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to start") + message = conn.recv() + if message["status"] != "ready": + raise RuntimeError( + f"Failed to start data parallel worker {rank}:\n{message.get('traceback', 'unknown error')}" + ) + except Exception: + self._stop_data_parallel_workers() + raise + + def _stop_data_parallel_workers(self) -> None: + for conn in self._dp_parent_conns: + with contextlib.suppress(Exception): + conn.send({"command": "shutdown"}) + + for proc in self._dp_processes: + proc.join(timeout=_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS) + if proc.is_alive(): + proc.kill() + proc.join(timeout=5) + + for conn in self._dp_parent_conns: + with contextlib.suppress(Exception): + conn.close() + + self._dp_parent_conns.clear() + self._dp_processes.clear() + + def _dispatch_prompt_command( + self, + command: str, + prompts: list[str] | list[torch.Tensor] | list[list[int]], + *, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, + **payload: Any, + ) -> list[Any]: + if not prompts: + return [] + + shard_results: list[tuple[list[int], list[Any]]] = [] + shard_indices = _split_data_parallel_indices(len(prompts), self._dp_size) + + for rank, conn in enumerate(self._dp_parent_conns): + indices = shard_indices[rank] + worker_indices = indices or [0] + worker_prompts = _slice_list_inputs(prompts, worker_indices) + conn.send( + { + "command": command, + "indices": indices, + "inputs": self.get_inputs( + worker_prompts, + images=_slice_optional_inputs(images, worker_indices), + videos=_slice_optional_inputs(videos, worker_indices), + audios=_slice_optional_inputs(audios, worker_indices), + ), + "prompts": worker_prompts, + **payload, + } + ) + + try: + for rank, conn in enumerate(self._dp_parent_conns): + if not conn.poll(self._dp_request_timeout): + raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `{command}`") + message = conn.recv() + if message["status"] != "ok": + raise RuntimeError( + f"Data parallel worker {rank} failed during `{command}`:\n" + f"{message.get('traceback', 'unknown error')}" + ) + shard_results.append((message["indices"], message["result"])) + except Exception: + self._stop_data_parallel_workers() + raise + + return _merge_data_parallel_results(len(prompts), shard_results) + + def _dispatch_text_command(self, command: str, prompts: list[str]) -> list[Any]: + if not prompts: + return [] + + 
shard_results: list[tuple[list[int], list[Any]]] = [] + shard_indices = _split_data_parallel_indices(len(prompts), self._dp_size) + + for rank, conn in enumerate(self._dp_parent_conns): + indices = shard_indices[rank] + worker_indices = indices or [0] + conn.send( + { + "command": command, + "indices": indices, + "prompts": _slice_list_inputs(prompts, worker_indices), + } + ) + + try: + for rank, conn in enumerate(self._dp_parent_conns): + if not conn.poll(self._dp_request_timeout): + raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `{command}`") + message = conn.recv() + if message["status"] != "ok": + raise RuntimeError( + f"Data parallel worker {rank} failed during `{command}`:\n" + f"{message.get('traceback', 'unknown error')}" + ) + shard_results.append((message["indices"], message["result"])) + except Exception: + self._stop_data_parallel_workers() + raise + + return _merge_data_parallel_results(len(prompts), shard_results) + + def generate( + self, + prompts: list[str] | list[torch.Tensor] | list[list[int]], + sampling_params: SamplingParams, + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, + **kwargs: Any, + ) -> list[tuple[list[list[int]], list[str]]]: + return self._dispatch_prompt_command( + "generate", + prompts, + images=images, + videos=videos, + audios=audios, + sampling_params=sampling_params, + kwargs=kwargs, + ) + + def generate_w_logprobs( + self, + prompts: list[str], + sampling_params: SamplingParams, + images: PromptImageInput | None = None, + audios: PromptAudioInput | None = None, + videos: PromptVideoInput | None = None, + **kwargs: Any, + ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]: + toks_str_logsprobs_prompt_logprobs = self._dispatch_prompt_command( + "generate_w_logprobs", + prompts, + images=images, + videos=videos, + audios=audios, + sampling_params=sampling_params, + kwargs=kwargs, + ) + return ( + [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs] + if sampling_params.prompt_logprobs is None + else toks_str_logsprobs_prompt_logprobs + ) + + def classify(self, prompts: list[str]) -> list[list[float]]: + return self._dispatch_text_command("classify", prompts) + + def embed( + self, + prompts: list[str], + images: PromptImageInput | None = None, + videos: PromptVideoInput | None = None, + audios: PromptAudioInput | None = None, + *args, + **kwargs, + ) -> list[list[float]]: + return self._dispatch_prompt_command( + "embed", + prompts, + images=images, + videos=videos, + audios=audios, + args=args, + kwargs=kwargs, + ) + + def encode(self, prompts: list[str]) -> list[list[float]]: + return self._dispatch_text_command("encode", prompts) + + def reward(self, prompts: list[str]) -> list[list[float]]: + return self._dispatch_text_command("reward", prompts) + + def score( + self, + text_1: str | list[str], + text_2: str | list[str], + *args, + **kwargs, + ) -> list[float]: + normalized_text_1, normalized_text_2 = _normalize_score_inputs(text_1, text_2) + if not normalized_text_1: + return [] + + shard_results: list[tuple[list[int], list[Any]]] = [] + shard_indices = _split_data_parallel_indices(len(normalized_text_1), self._dp_size) + + for rank, conn in enumerate(self._dp_parent_conns): + indices = shard_indices[rank] + worker_indices = indices or [0] + conn.send( + { + "command": "score", + "indices": indices, + "text_1": _slice_list_inputs(normalized_text_1, worker_indices), + "text_2": _slice_list_inputs(normalized_text_2, 
worker_indices), + "args": args, + "kwargs": kwargs, + } + ) + + try: + for rank, conn in enumerate(self._dp_parent_conns): + if not conn.poll(self._dp_request_timeout): + raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `score`") + message = conn.recv() + if message["status"] != "ok": + raise RuntimeError( + f"Data parallel worker {rank} failed during `score`:\n" + f"{message.get('traceback', 'unknown error')}" + ) + shard_results.append((message["indices"], message["result"])) + except Exception: + self._stop_data_parallel_workers() + raise + + return _merge_data_parallel_results(len(normalized_text_1), shard_results) + + def __exit__(self, exc_type, exc_value, traceback): + self._stop_data_parallel_workers() + clear_ascend_config() + cleanup_dist_env_and_memory() + + +DataParallelVllmRunner = DPVllmRunner + + class HfRunner: def get_default_device(self): return "cpu" if current_platform.is_cpu() else current_platform.device_type diff --git a/tests/e2e/multicard/2-cards/test_pipeline_parallel.py b/tests/e2e/multicard/2-cards/test_pipeline_parallel.py deleted file mode 100644 index 4023ee0e..00000000 --- a/tests/e2e/multicard/2-cards/test_pipeline_parallel.py +++ /dev/null @@ -1,49 +0,0 @@ -# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. -# Copyright 2023 The vLLM team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# This file is a part of the vllm-ascend project. -# -import pytest - -from tests.e2e.conftest import VllmRunner - -MODELS = [ - "Qwen/Qwen3-0.6B", - "deepseek-ai/DeepSeek-V2-Lite-Chat", -] - -TENSOR_PARALLELS = [1] -PIPELINE_PARALLELS = [2] -DIST_EXECUTOR_BACKEND = ["mp", "ray"] - -prompts = [ - "Hello, my name is", - "The future of AI is", -] - - -@pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS) -@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS) -@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND) -def test_models_pp2(model: str, tp_size: int, pp_size: int, distributed_executor_backend: str) -> None: - with VllmRunner( - model, - tensor_parallel_size=tp_size, - pipeline_parallel_size=pp_size, - cudagraph_capture_sizes=[1, 2, 4, 8], - distributed_executor_backend=distributed_executor_backend, - gpu_memory_utilization=0.7, - ) as vllm_model: - vllm_model.generate_greedy(prompts, 64) diff --git a/tests/e2e/multicard/4-cards/test_pipeline_parallel.py b/tests/e2e/multicard/4-cards/test_pipeline_parallel.py new file mode 100644 index 00000000..4d086f9d --- /dev/null +++ b/tests/e2e/multicard/4-cards/test_pipeline_parallel.py @@ -0,0 +1,88 @@ +# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved. +# Copyright 2023 The vLLM team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# This file is a part of the vllm-ascend project. +# +import pytest + +from tests.e2e.conftest import DPVllmRunner, VllmRunner, wait_until_npu_memory_free +from tests.e2e.model_utils import check_outputs_equal + +DS3 = "deepseek-ai/DeepSeek-V2-Lite-Chat" +MODELS = [ + DS3, +] +MOE_MODELS = [ + DS3, +] + +DATA_PARALLELS = [2] +TENSOR_PARALLELS = [1,2] +PIPELINE_PARALLELS = [2] +DIST_EXECUTOR_BACKEND = ["mp", "ray"] + +prompts = [ + "Hello, my name is", + "The future of AI is", +] +GOLDEN = [([100000, 17464, 11, 601, 1210, 317, 46462, 608, 245, 4541, 7712, 13, 2682, 6207, 317, 276, 2774, 340, 366, 254, 1608, 2784], 'Hello, my name is***** am a computer expert. My goal is to provide you with the best experience'), ([100000, 549, 3680, 280, 20838, 317, 6464, 11, 548, 359, 487, 82, 441, 1673, 895, 10694, 13, 1733, 20838, 5495, 11106, 276], 'The future of AI is bright, but it’s not without its challenges. As AI technology continues to')] + + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS) +@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS) +@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND) +@wait_until_npu_memory_free(target_free_percentage=0.6) +def test_models_pp2_tp2(model: str, tp_size: int, pp_size: int, distributed_executor_backend: str) -> None: + with VllmRunner( + model, + tensor_parallel_size=tp_size, + pipeline_parallel_size=pp_size, + cudagraph_capture_sizes=[1, 2, 4], + distributed_executor_backend=distributed_executor_backend, + gpu_memory_utilization=0.7, + enable_expert_parallel=model in MOE_MODELS, + ) as vllm_model: + outputs = vllm_model.generate_greedy(prompts, 16) + check_outputs_equal( + outputs_0_lst=outputs, + outputs_1_lst=GOLDEN, + name_0=f"{model}-tp{tp_size}pp{pp_size}", + name_1="GOLDEN", + ) + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dp_size", DATA_PARALLELS) +@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS) +@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND) +@wait_until_npu_memory_free(target_free_percentage=0.6) +def test_models_pp2_dp2(model: str, dp_size: int, pp_size: int, distributed_executor_backend: str) -> None: + with DPVllmRunner( + model, + data_parallel_size=dp_size, + pipeline_parallel_size=pp_size, + cudagraph_capture_sizes=[1, 2, 4], + distributed_executor_backend=distributed_executor_backend, + gpu_memory_utilization=0.7, + enable_expert_parallel=model in MOE_MODELS, + ) as vllm_model: + outputs = vllm_model.generate_greedy(prompts, 16) + check_outputs_equal( + outputs_0_lst=outputs, + outputs_1_lst=GOLDEN, + name_0=f"{model}-dp{dp_size}pp{pp_size}", + name_1="GOLDEN", + ) + diff --git a/vllm_ascend/distributed/parallel_state.py b/vllm_ascend/distributed/parallel_state.py index 5f1c37f1..5ed7d3dd 100644 --- a/vllm_ascend/distributed/parallel_state.py +++ b/vllm_ascend/distributed/parallel_state.py @@ -38,19 +38,16 @@ def init_ascend_model_parallel( global_tp_size = parallel_config.tensor_parallel_size global_dp_size = parallel_config.data_parallel_size global_pp_size = 
parallel_config.pipeline_parallel_size + global_pcp_size = parallel_config.prefill_context_parallel_size # The layout of all ranks: ExternalDP * EP # ExternalDP is the data parallel group that is not part of the model, # every dp rank can generate independently (in verl integration). all_ranks = torch.arange(world_size).reshape( - -1, global_dp_size * parallel_config.prefill_context_parallel_size * global_tp_size - ) - # TODO: all_ranks should be the same as vllm_all_ranks, all_ranks needs to be removed in the future. - vllm_all_ranks = torch.arange(world_size).reshape( -1, global_dp_size, global_pp_size, - parallel_config.prefill_context_parallel_size, + global_pcp_size, global_tp_size, ) @@ -59,7 +56,6 @@ def init_ascend_model_parallel( global _P_TP assert _P_TP is None, "distributed prefill tensor parallel group is already initialized" prefill_tensor_model_parallel_size = pd_tp_ratio - pcp_size = parallel_config.prefill_context_parallel_size # divide alltoall groups if pd_head_ratio > 1 and get_current_vllm_config().kv_transfer_config.is_kv_producer: num_head_replica = get_ascend_config().num_head_replica @@ -68,13 +64,16 @@ def init_ascend_model_parallel( group_ranks = all_ranks.view(-1, prefill_tensor_model_parallel_size).unbind(0) else: group_ranks = all_ranks.clone().view( - global_dp_size * pcp_size, -1, num_head_replica + global_dp_size * global_pp_size * global_pcp_size, -1, num_head_replica ) # [DP_size, num_head, num_head_replica] group_ranks = group_ranks.permute(0, 2, 1) group_ranks = group_ranks.reshape(-1, group_ranks.size(-1)) # [DP_size * num_head_replica, num_head] alltoall_group_size = group_ranks.size(-1) // remote_tp_size group_ranks = group_ranks.unsqueeze(-1).view( - global_dp_size * pcp_size, num_head_replica, -1, alltoall_group_size + global_dp_size * global_pp_size * global_pcp_size, + num_head_replica, + -1, + alltoall_group_size, ) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size] group_ranks = group_ranks.reshape(-1, alltoall_group_size).unbind(0) group_ranks = [x.tolist() for x in group_ranks] @@ -82,10 +81,18 @@ def init_ascend_model_parallel( num = next((i for i, ranks in enumerate(group_ranks) if local_rank in ranks), None) _P_TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name=f"p_tp_{num}") - global _MC2 - group_ranks = all_ranks.unbind(0) + # EP like group ranks + group_ranks = ( + all_ranks.transpose(1, 2) + .reshape( + -1, + global_dp_size * global_pcp_size * global_tp_size, + ) + .unbind(0) + ) group_ranks = [x.tolist() for x in group_ranks] + global _MC2 _MC2 = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name="mc2") if get_ascend_config().eplb_config.dynamic_eplb: @@ -94,6 +101,12 @@ def init_ascend_model_parallel( group_ranks, get_world_group().local_rank, backend, group_name="dynamic_eplb" ) + if get_ascend_config().multistream_overlap_gate: + global _FC3_QUANT_X + _FC3_QUANT_X = init_model_parallel_group( + group_ranks, get_world_group().local_rank, backend, group_name="fc3_quant_x" + ) + # Initialize fine-grained TP process groups on Ascend for four components: # 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`) # 2. O Proj: attention output projection (`oproj_tensor_parallel_size`) @@ -182,7 +195,7 @@ def init_ascend_model_parallel( # 2. If it is not None, and the module tp_group is same as the global tp_group. # 3. If it is not None, and the module tp_group is different from the global tp_group.(eg. 
flashcomm2_otp) group_ranks = [] - pp_group_ranks = vllm_all_ranks.transpose(2, 4).reshape(-1, global_pp_size) + pp_group_ranks = all_ranks.transpose(2, 4).reshape(-1, global_pp_size) if module_tp_group_ranks is None: # If it is None, then the TP_size of this shard weight is 1. shard_weight_group_ranks = pp_group_ranks.transpose(0, 1).unbind(0) @@ -209,17 +222,9 @@ def init_ascend_model_parallel( _SHARD_WEIGHT = create_shard_weight_group(None) else: # For standard tp, use global tp group_ranks - tp_group_ranks = vllm_all_ranks.view(-1, global_tp_size) + tp_group_ranks = all_ranks.view(-1, global_tp_size) _SHARD_WEIGHT = create_shard_weight_group(tp_group_ranks) - if get_ascend_config().multistream_overlap_gate: - global _FC3_QUANT_X - group_ranks = all_ranks.unbind(0) - group_ranks = [x.tolist() for x in group_ranks] - _FC3_QUANT_X = init_model_parallel_group( - group_ranks, get_world_group().local_rank, backend, group_name="fc3_quant_x" - ) - def model_parallel_initialized(): return _MC2 is not None
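
For reference, a minimal standalone sketch of how the refactored MC2 groups fall out of the rank grid. It is not part of the patch; the parallel sizes below are hypothetical (world_size=8, DP=2, PP=2, PCP=1, TP=2), and only the transpose/reshape mirrors the new logic in `init_ascend_model_parallel`: each MC2 group now spans every DP * PCP * TP rank within a single PP stage, matching vLLM's EP group layout.

```python
# Sketch only: illustrates the EP-like MC2 group construction from the patch.
# The sizes are hypothetical; the layout follows all_ranks in parallel_state.py.
import torch

world_size = 8
dp, pp, pcp, tp = 2, 2, 1, 2

# Rank grid laid out as (ExternalDP, DP, PP, PCP, TP), like all_ranks.
all_ranks = torch.arange(world_size).reshape(-1, dp, pp, pcp, tp)

# Swap the DP and PP axes so each flattened row collects all DP * PCP * TP
# ranks of one PP stage -- one MC2/EP group per pipeline stage.
mc2_groups = [
    x.tolist()
    for x in all_ranks.transpose(1, 2).reshape(-1, dp * pcp * tp).unbind(0)
]
print(mc2_groups)  # [[0, 1, 4, 5], [2, 3, 6, 7]]
```

With the old `all_ranks.unbind(0)` layout, a PP deployment would have folded ranks from different pipeline stages into one MC2 group; grouping per PP stage is what makes MC2 compatible with PP.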