bugfix(MC2): refactor the comm group of MC2 to be compatible with PP (#7291)

### What this PR does / why we need it?
This PR refactors the communication group of MC2 to keep it consistent
with vllm's EP group, making it compatible with PP.

- vLLM version: v0.17.0
- vLLM main:
4034c3d32e
---------
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
This commit is contained in:
Qiu
2026-03-23 15:44:21 +08:00
committed by GitHub
parent 8527b49764
commit 71df17f4e6
5 changed files with 571 additions and 89 deletions

View File

@@ -127,8 +127,6 @@ e2e-multicard-2-cards:
estimated_time: 180
- name: tests/e2e/multicard/2-cards/test_offline_inference_distributed.py::test_qwen3_w4a4_distributed_tp2
estimated_time: 202
- name: tests/e2e/multicard/2-cards/test_pipeline_parallel.py
estimated_time: 357
- name: tests/e2e/multicard/2-cards/test_prefix_caching.py
estimated_time: 470
- name: tests/e2e/multicard/2-cards/test_quantization.py
@@ -165,3 +163,5 @@ e2e-multicard-4-cards:
is_skipped: true
- name: tests/e2e/multicard/4-cards/spec_decode/test_mtp_qwen3_next.py
estimated_time: 1340
- name: tests/e2e/multicard/4-cards/test_pipeline_parallel.py
estimated_time: 357

View File

@@ -453,7 +453,6 @@ class RemoteEPDServer(RemoteOpenAIServer):
self.env_dict.update(env_dict)
self.env_dict["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
self.env_dict["VLLM_USE_V1"] = "1"
self.env_dict["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True"
self.env_dict["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@@ -626,6 +625,126 @@ class DisaggEpdProxy(RemoteEPDServer):
super()._terminate_server()
_DP_RUNNER_START_TIMEOUT_SECONDS = 900.0
_DP_RUNNER_REQUEST_TIMEOUT_SECONDS = 900.0
_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS = 30.0
def _split_data_parallel_indices(num_items: int, dp_size: int) -> list[list[int]]:
if num_items < 0:
raise ValueError("num_items must be non-negative")
if dp_size <= 0:
raise ValueError("dp_size must be positive")
floor = num_items // dp_size
remainder = num_items % dp_size
def start(rank: int) -> int:
return rank * floor + min(rank, remainder)
return [list(range(start(rank), start(rank + 1))) for rank in range(dp_size)]
def _slice_optional_inputs(inputs: PromptImageInput | PromptAudioInput | PromptVideoInput | None, indices: list[int]):
if inputs is None:
return None
return [inputs[index] for index in indices]
def _slice_list_inputs(items: list[Any], indices: list[int]) -> list[Any]:
return [items[index] for index in indices]
def _merge_data_parallel_results(total_items: int, shard_results: list[tuple[list[int], list[Any]]]) -> list[Any]:
merged: list[Any] = [None] * total_items
for indices, results in shard_results:
if not indices:
continue
if len(indices) != len(results):
raise RuntimeError("Mismatched result count returned by data parallel worker")
for index, result in zip(indices, results):
merged[index] = result
if any(result is None for result in merged):
raise RuntimeError("Some data parallel results were not returned")
return merged
def _normalize_score_inputs(text_1: str | list[str], text_2: str | list[str]) -> tuple[list[str], list[str]]:
if isinstance(text_1, str) and isinstance(text_2, str):
return [text_1], [text_2]
if isinstance(text_1, str):
return [text_1] * len(text_2), list(text_2)
if isinstance(text_2, str):
return list(text_1), [text_2] * len(text_1)
if len(text_1) != len(text_2):
raise ValueError("`text_1` and `text_2` must have the same length")
return list(text_1), list(text_2)
def _run_vllm_runner_dp_worker(conn, llm_kwargs: dict[str, Any], dp_rank: int, dp_size: int, master_port: int) -> None:
llm = None
try:
os.environ["VLLM_DP_RANK"] = str(dp_rank)
os.environ["VLLM_DP_RANK_LOCAL"] = str(dp_rank)
os.environ["VLLM_DP_SIZE"] = str(dp_size)
os.environ["VLLM_DP_MASTER_IP"] = "127.0.0.1"
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
llm = LLM(**llm_kwargs)
conn.send({"status": "ready", "rank": dp_rank})
while True:
request = conn.recv()
command = request["command"]
if command == "shutdown":
break
result: Any
if command == "generate":
req_outputs = llm.generate(
request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"]
)
result = VllmRunner._finalize_generate_outputs(req_outputs)
elif command == "generate_w_logprobs":
req_outputs = llm.generate(
request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"]
)
result = VllmRunner._final_steps_generate_w_logprobs(req_outputs)
elif command == "classify":
req_outputs = llm.classify(request["prompts"])
result = [req_output.outputs.probs for req_output in req_outputs]
elif command == "embed":
req_outputs = llm.embed(request["inputs"], *request["args"], **request["kwargs"])
result = [req_output.outputs.embedding for req_output in req_outputs]
elif command == "encode":
req_outputs = llm.encode(request["prompts"])
result = [req_output.outputs.data for req_output in req_outputs]
elif command == "reward":
req_outputs = llm.reward(request["prompts"])
result = [req_output.outputs.data for req_output in req_outputs]
elif command == "score":
req_outputs = llm.score(request["text_1"], request["text_2"], *request["args"], **request["kwargs"])
result = [req_output.outputs.score for req_output in req_outputs]
else:
raise ValueError(f"Unsupported data parallel command: {command}")
conn.send({"status": "ok", "rank": dp_rank, "indices": request["indices"], "result": result})
except Exception:
with contextlib.suppress(Exception):
conn.send({"status": "error", "rank": dp_rank, "traceback": traceback.format_exc()})
raise
finally:
if llm is not None:
del llm
clear_ascend_config()
cleanup_dist_env_and_memory()
with contextlib.suppress(Exception):
conn.close()
class VllmRunner:
def __init__(
self,
@@ -645,6 +764,10 @@ class VllmRunner:
quantization: str | None = None,
**kwargs,
) -> None:
data_parallel_size = int(kwargs.get("data_parallel_size", 1))
if data_parallel_size > 1:
raise ValueError("VllmRunner does not support `data_parallel_size > 1`; use `DPVllmRunner` instead.")
self.model = LLM(
model=model_name,
runner=runner,
@@ -664,6 +787,22 @@ class VllmRunner:
**kwargs,
)
@staticmethod
def _finalize_generate_outputs(req_outputs: list[RequestOutput]) -> list[tuple[list[list[int]], list[str]]]:
outputs: list[tuple[list[list[int]], list[str]]] = []
for req_output in req_outputs:
prompt_str = req_output.prompt
prompt_ids = req_output.prompt_token_ids
req_sample_output_ids: list[list[int]] = []
req_sample_output_strs: list[str] = []
for sample in req_output.outputs:
output_str = sample.text
output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append((prompt_str or "") + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs
def get_inputs(
self,
prompts: list[str] | list[torch.Tensor] | list[int],
@@ -698,7 +837,7 @@ class VllmRunner:
def generate(
self,
prompts: list[str] | list[torch.Tensor],
prompts: list[str] | list[torch.Tensor] | list[list[int]],
sampling_params: SamplingParams,
images: PromptImageInput | None = None,
videos: PromptVideoInput | None = None,
@@ -706,22 +845,8 @@ class VllmRunner:
**kwargs: Any,
) -> list[tuple[list[list[int]], list[str]]]:
inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
req_outputs = self.model.generate(inputs, sampling_params=sampling_params, **kwargs)
outputs: list[tuple[list[list[int]], list[str]]] = []
for req_output in req_outputs:
prompt_str = req_output.prompt
prompt_ids = req_output.prompt_token_ids
req_sample_output_ids: list[list[int]] = []
req_sample_output_strs: list[str] = []
for sample in req_output.outputs:
output_str = sample.text
output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append((prompt_str or "") + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs
return self._finalize_generate_outputs(req_outputs)
@staticmethod
def _final_steps_generate_w_logprobs(
@@ -760,7 +885,7 @@ class VllmRunner:
def generate_greedy(
self,
prompts: list[str] | list[torch.Tensor],
prompts: list[str] | list[torch.Tensor] | list[list[int]],
max_tokens: int,
images: PromptImageInput | None = None,
videos: PromptVideoInput | None = None,
@@ -842,6 +967,319 @@ class VllmRunner:
cleanup_dist_env_and_memory()
class DPVllmRunner(VllmRunner):
def __init__(
self,
model_name: str,
runner: RunnerOption = "auto",
convert: ConvertOption = "auto",
tokenizer_name: str | None = None,
tokenizer_mode: str = "auto",
max_model_len: int | None = 1024,
dtype: str = "auto",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
block_size: int = 16,
enable_chunked_prefill: bool = True,
swap_space: int = 4,
enforce_eager: bool | None = False,
quantization: str | None = None,
data_parallel_size: int = 2,
**kwargs,
) -> None:
if data_parallel_size < 2:
raise ValueError("DPVllmRunner requires `data_parallel_size >= 2`")
self._dp_size = data_parallel_size
self._dp_parent_conns: list[Any] = []
self._dp_processes: list[Any] = []
self._dp_start_timeout = float(kwargs.pop("dp_start_timeout", _DP_RUNNER_START_TIMEOUT_SECONDS))
self._dp_request_timeout = float(kwargs.pop("dp_request_timeout", _DP_RUNNER_REQUEST_TIMEOUT_SECONDS))
llm_kwargs = dict(
model=model_name,
runner=runner,
convert=convert,
tokenizer=tokenizer_name,
tokenizer_mode=tokenizer_mode,
trust_remote_code=True,
dtype=dtype,
swap_space=swap_space,
enforce_eager=enforce_eager,
disable_log_stats=disable_log_stats,
tensor_parallel_size=tensor_parallel_size,
max_model_len=max_model_len,
block_size=block_size,
enable_chunked_prefill=enable_chunked_prefill,
quantization=quantization,
**kwargs,
)
cleanup_dist_env_and_memory()
self._start_data_parallel_workers(llm_kwargs)
@property
def model(self) -> LLM:
raise RuntimeError("Direct access to `runner.model` is not supported by `DPVllmRunner`.")
def _start_data_parallel_workers(self, llm_kwargs: dict[str, Any]) -> None:
ctx = multiprocessing.get_context("spawn")
master_port = get_open_port()
try:
for dp_rank in range(self._dp_size):
parent_conn, child_conn = ctx.Pipe()
proc = ctx.Process(
target=_run_vllm_runner_dp_worker,
args=(child_conn, llm_kwargs, dp_rank, self._dp_size, master_port),
)
proc.start()
child_conn.close()
self._dp_parent_conns.append(parent_conn)
self._dp_processes.append(proc)
for rank, conn in enumerate(self._dp_parent_conns):
if not conn.poll(self._dp_start_timeout):
raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to start")
message = conn.recv()
if message["status"] != "ready":
raise RuntimeError(
f"Failed to start data parallel worker {rank}:\n{message.get('traceback', 'unknown error')}"
)
except Exception:
self._stop_data_parallel_workers()
raise
def _stop_data_parallel_workers(self) -> None:
for conn in self._dp_parent_conns:
with contextlib.suppress(Exception):
conn.send({"command": "shutdown"})
for proc in self._dp_processes:
proc.join(timeout=_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS)
if proc.is_alive():
proc.kill()
proc.join(timeout=5)
for conn in self._dp_parent_conns:
with contextlib.suppress(Exception):
conn.close()
self._dp_parent_conns.clear()
self._dp_processes.clear()
def _dispatch_prompt_command(
self,
command: str,
prompts: list[str] | list[torch.Tensor] | list[list[int]],
*,
images: PromptImageInput | None = None,
videos: PromptVideoInput | None = None,
audios: PromptAudioInput | None = None,
**payload: Any,
) -> list[Any]:
if not prompts:
return []
shard_results: list[tuple[list[int], list[Any]]] = []
shard_indices = _split_data_parallel_indices(len(prompts), self._dp_size)
for rank, conn in enumerate(self._dp_parent_conns):
indices = shard_indices[rank]
worker_indices = indices or [0]
worker_prompts = _slice_list_inputs(prompts, worker_indices)
conn.send(
{
"command": command,
"indices": indices,
"inputs": self.get_inputs(
worker_prompts,
images=_slice_optional_inputs(images, worker_indices),
videos=_slice_optional_inputs(videos, worker_indices),
audios=_slice_optional_inputs(audios, worker_indices),
),
"prompts": worker_prompts,
**payload,
}
)
try:
for rank, conn in enumerate(self._dp_parent_conns):
if not conn.poll(self._dp_request_timeout):
raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `{command}`")
message = conn.recv()
if message["status"] != "ok":
raise RuntimeError(
f"Data parallel worker {rank} failed during `{command}`:\n"
f"{message.get('traceback', 'unknown error')}"
)
shard_results.append((message["indices"], message["result"]))
except Exception:
self._stop_data_parallel_workers()
raise
return _merge_data_parallel_results(len(prompts), shard_results)
def _dispatch_text_command(self, command: str, prompts: list[str]) -> list[Any]:
if not prompts:
return []
shard_results: list[tuple[list[int], list[Any]]] = []
shard_indices = _split_data_parallel_indices(len(prompts), self._dp_size)
for rank, conn in enumerate(self._dp_parent_conns):
indices = shard_indices[rank]
worker_indices = indices or [0]
conn.send(
{
"command": command,
"indices": indices,
"prompts": _slice_list_inputs(prompts, worker_indices),
}
)
try:
for rank, conn in enumerate(self._dp_parent_conns):
if not conn.poll(self._dp_request_timeout):
raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `{command}`")
message = conn.recv()
if message["status"] != "ok":
raise RuntimeError(
f"Data parallel worker {rank} failed during `{command}`:\n"
f"{message.get('traceback', 'unknown error')}"
)
shard_results.append((message["indices"], message["result"]))
except Exception:
self._stop_data_parallel_workers()
raise
return _merge_data_parallel_results(len(prompts), shard_results)
def generate(
self,
prompts: list[str] | list[torch.Tensor] | list[list[int]],
sampling_params: SamplingParams,
images: PromptImageInput | None = None,
videos: PromptVideoInput | None = None,
audios: PromptAudioInput | None = None,
**kwargs: Any,
) -> list[tuple[list[list[int]], list[str]]]:
return self._dispatch_prompt_command(
"generate",
prompts,
images=images,
videos=videos,
audios=audios,
sampling_params=sampling_params,
kwargs=kwargs,
)
def generate_w_logprobs(
self,
prompts: list[str],
sampling_params: SamplingParams,
images: PromptImageInput | None = None,
audios: PromptAudioInput | None = None,
videos: PromptVideoInput | None = None,
**kwargs: Any,
) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
toks_str_logsprobs_prompt_logprobs = self._dispatch_prompt_command(
"generate_w_logprobs",
prompts,
images=images,
videos=videos,
audios=audios,
sampling_params=sampling_params,
kwargs=kwargs,
)
return (
[x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
if sampling_params.prompt_logprobs is None
else toks_str_logsprobs_prompt_logprobs
)
def classify(self, prompts: list[str]) -> list[list[float]]:
return self._dispatch_text_command("classify", prompts)
def embed(
self,
prompts: list[str],
images: PromptImageInput | None = None,
videos: PromptVideoInput | None = None,
audios: PromptAudioInput | None = None,
*args,
**kwargs,
) -> list[list[float]]:
return self._dispatch_prompt_command(
"embed",
prompts,
images=images,
videos=videos,
audios=audios,
args=args,
kwargs=kwargs,
)
def encode(self, prompts: list[str]) -> list[list[float]]:
return self._dispatch_text_command("encode", prompts)
def reward(self, prompts: list[str]) -> list[list[float]]:
return self._dispatch_text_command("reward", prompts)
def score(
self,
text_1: str | list[str],
text_2: str | list[str],
*args,
**kwargs,
) -> list[float]:
normalized_text_1, normalized_text_2 = _normalize_score_inputs(text_1, text_2)
if not normalized_text_1:
return []
shard_results: list[tuple[list[int], list[Any]]] = []
shard_indices = _split_data_parallel_indices(len(normalized_text_1), self._dp_size)
for rank, conn in enumerate(self._dp_parent_conns):
indices = shard_indices[rank]
worker_indices = indices or [0]
conn.send(
{
"command": "score",
"indices": indices,
"text_1": _slice_list_inputs(normalized_text_1, worker_indices),
"text_2": _slice_list_inputs(normalized_text_2, worker_indices),
"args": args,
"kwargs": kwargs,
}
)
try:
for rank, conn in enumerate(self._dp_parent_conns):
if not conn.poll(self._dp_request_timeout):
raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `score`")
message = conn.recv()
if message["status"] != "ok":
raise RuntimeError(
f"Data parallel worker {rank} failed during `score`:\n"
f"{message.get('traceback', 'unknown error')}"
)
shard_results.append((message["indices"], message["result"]))
except Exception:
self._stop_data_parallel_workers()
raise
return _merge_data_parallel_results(len(normalized_text_1), shard_results)
def __exit__(self, exc_type, exc_value, traceback):
self._stop_data_parallel_workers()
clear_ascend_config()
cleanup_dist_env_and_memory()
DataParallelVllmRunner = DPVllmRunner
class HfRunner:
def get_default_device(self):
return "cpu" if current_platform.is_cpu() else current_platform.device_type

View File

@@ -1,49 +0,0 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import pytest
from tests.e2e.conftest import VllmRunner
MODELS = [
"Qwen/Qwen3-0.6B",
"deepseek-ai/DeepSeek-V2-Lite-Chat",
]
TENSOR_PARALLELS = [1]
PIPELINE_PARALLELS = [2]
DIST_EXECUTOR_BACKEND = ["mp", "ray"]
prompts = [
"Hello, my name is",
"The future of AI is",
]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND)
def test_models_pp2(model: str, tp_size: int, pp_size: int, distributed_executor_backend: str) -> None:
with VllmRunner(
model,
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
cudagraph_capture_sizes=[1, 2, 4, 8],
distributed_executor_backend=distributed_executor_backend,
gpu_memory_utilization=0.7,
) as vllm_model:
vllm_model.generate_greedy(prompts, 64)

View File

@@ -0,0 +1,88 @@
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
#
import pytest
from tests.e2e.conftest import DPVllmRunner, VllmRunner, wait_until_npu_memory_free
from tests.e2e.model_utils import check_outputs_equal
DS3 = "deepseek-ai/DeepSeek-V2-Lite-Chat"
MODELS = [
DS3,
]
MOE_MODELS = [
DS3,
]
DATA_PARALLELS = [2]
TENSOR_PARALLELS = [1,2]
PIPELINE_PARALLELS = [2]
DIST_EXECUTOR_BACKEND = ["mp", "ray"]
prompts = [
"Hello, my name is",
"The future of AI is",
]
GOLDEN = [([100000, 17464, 11, 601, 1210, 317, 46462, 608, 245, 4541, 7712, 13, 2682, 6207, 317, 276, 2774, 340, 366, 254, 1608, 2784], 'Hello, my name is***** am a computer expert. My goal is to provide you with the best experience'), ([100000, 549, 3680, 280, 20838, 317, 6464, 11, 548, 359, 487, 82, 441, 1673, 895, 10694, 13, 1733, 20838, 5495, 11106, 276], 'The future of AI is bright, but its not without its challenges. As AI technology continues to')]
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("tp_size", TENSOR_PARALLELS)
@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND)
@wait_until_npu_memory_free(target_free_percentage=0.6)
def test_models_pp2_tp2(model: str, tp_size: int, pp_size: int, distributed_executor_backend: str) -> None:
with VllmRunner(
model,
tensor_parallel_size=tp_size,
pipeline_parallel_size=pp_size,
cudagraph_capture_sizes=[1, 2, 4],
distributed_executor_backend=distributed_executor_backend,
gpu_memory_utilization=0.7,
enable_expert_parallel=model in MOE_MODELS,
) as vllm_model:
outputs = vllm_model.generate_greedy(prompts, 16)
check_outputs_equal(
outputs_0_lst=outputs,
outputs_1_lst=GOLDEN,
name_0=f"{model}-tp{tp_size}pp{pp_size}",
name_1="GOLDEN",
)
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dp_size", DATA_PARALLELS)
@pytest.mark.parametrize("pp_size", PIPELINE_PARALLELS)
@pytest.mark.parametrize("distributed_executor_backend", DIST_EXECUTOR_BACKEND)
@wait_until_npu_memory_free(target_free_percentage=0.6)
def test_models_pp2_dp2(model: str, dp_size: int, pp_size: int, distributed_executor_backend: str) -> None:
with DPVllmRunner(
model,
data_parallel_size=dp_size,
pipeline_parallel_size=pp_size,
cudagraph_capture_sizes=[1, 2, 4],
distributed_executor_backend=distributed_executor_backend,
gpu_memory_utilization=0.7,
enable_expert_parallel=model in MOE_MODELS,
) as vllm_model:
outputs = vllm_model.generate_greedy(prompts, 16)
check_outputs_equal(
outputs_0_lst=outputs,
outputs_1_lst=GOLDEN,
name_0=f"{model}-dp{dp_size}pp{pp_size}",
name_1="GOLDEN",
)

View File

@@ -38,19 +38,16 @@ def init_ascend_model_parallel(
global_tp_size = parallel_config.tensor_parallel_size
global_dp_size = parallel_config.data_parallel_size
global_pp_size = parallel_config.pipeline_parallel_size
global_pcp_size = parallel_config.prefill_context_parallel_size
# The layout of all ranks: ExternalDP * EP
# ExternalDP is the data parallel group that is not part of the model,
# every dp rank can generate independently (in verl integration).
all_ranks = torch.arange(world_size).reshape(
-1, global_dp_size * parallel_config.prefill_context_parallel_size * global_tp_size
)
# TODO: all_ranks should be the same as vllm_all_ranks, all_ranks needs to be removed in the future.
vllm_all_ranks = torch.arange(world_size).reshape(
-1,
global_dp_size,
global_pp_size,
parallel_config.prefill_context_parallel_size,
global_pcp_size,
global_tp_size,
)
@@ -59,7 +56,6 @@ def init_ascend_model_parallel(
global _P_TP
assert _P_TP is None, "distributed prefill tensor parallel group is already initialized"
prefill_tensor_model_parallel_size = pd_tp_ratio
pcp_size = parallel_config.prefill_context_parallel_size
# divide alltoall groups
if pd_head_ratio > 1 and get_current_vllm_config().kv_transfer_config.is_kv_producer:
num_head_replica = get_ascend_config().num_head_replica
@@ -68,13 +64,16 @@ def init_ascend_model_parallel(
group_ranks = all_ranks.view(-1, prefill_tensor_model_parallel_size).unbind(0)
else:
group_ranks = all_ranks.clone().view(
global_dp_size * pcp_size, -1, num_head_replica
global_dp_size * global_pp_size * global_pcp_size, -1, num_head_replica
) # [DP_size, num_head, num_head_replica]
group_ranks = group_ranks.permute(0, 2, 1)
group_ranks = group_ranks.reshape(-1, group_ranks.size(-1)) # [DP_size * num_head_replica, num_head]
alltoall_group_size = group_ranks.size(-1) // remote_tp_size
group_ranks = group_ranks.unsqueeze(-1).view(
global_dp_size * pcp_size, num_head_replica, -1, alltoall_group_size
global_dp_size * global_pp_size * global_pcp_size,
num_head_replica,
-1,
alltoall_group_size,
) # [DP_size, num_head_replica, num_alltoall_group, alltoall_group_size]
group_ranks = group_ranks.reshape(-1, alltoall_group_size).unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
@@ -82,10 +81,18 @@ def init_ascend_model_parallel(
num = next((i for i, ranks in enumerate(group_ranks) if local_rank in ranks), None)
_P_TP = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name=f"p_tp_{num}")
global _MC2
group_ranks = all_ranks.unbind(0)
# EP like group ranks
group_ranks = (
all_ranks.transpose(1, 2)
.reshape(
-1,
global_dp_size * global_pcp_size * global_tp_size,
)
.unbind(0)
)
group_ranks = [x.tolist() for x in group_ranks]
global _MC2
_MC2 = init_model_parallel_group(group_ranks, get_world_group().local_rank, backend, group_name="mc2")
if get_ascend_config().eplb_config.dynamic_eplb:
@@ -94,6 +101,12 @@ def init_ascend_model_parallel(
group_ranks, get_world_group().local_rank, backend, group_name="dynamic_eplb"
)
if get_ascend_config().multistream_overlap_gate:
global _FC3_QUANT_X
_FC3_QUANT_X = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="fc3_quant_x"
)
# Initialize fine-grained TP process groups on Ascend for four components:
# 1. LM Head: output logits projection (`lmhead_tensor_parallel_size`)
# 2. O Proj: attention output projection (`oproj_tensor_parallel_size`)
@@ -182,7 +195,7 @@ def init_ascend_model_parallel(
# 2. If it is not None, and the module tp_group is same as the global tp_group.
# 3. If it is not None, and the module tp_group is different from the global tp_group.(eg. flashcomm2_otp)
group_ranks = []
pp_group_ranks = vllm_all_ranks.transpose(2, 4).reshape(-1, global_pp_size)
pp_group_ranks = all_ranks.transpose(2, 4).reshape(-1, global_pp_size)
if module_tp_group_ranks is None:
# If it is None, then the TP_size of this shard weight is 1.
shard_weight_group_ranks = pp_group_ranks.transpose(0, 1).unbind(0)
@@ -209,17 +222,9 @@ def init_ascend_model_parallel(
_SHARD_WEIGHT = create_shard_weight_group(None)
else:
# For standard tp, use global tp group_ranks
tp_group_ranks = vllm_all_ranks.view(-1, global_tp_size)
tp_group_ranks = all_ranks.view(-1, global_tp_size)
_SHARD_WEIGHT = create_shard_weight_group(tp_group_ranks)
if get_ascend_config().multistream_overlap_gate:
global _FC3_QUANT_X
group_ranks = all_ranks.unbind(0)
group_ranks = [x.tolist() for x in group_ranks]
_FC3_QUANT_X = init_model_parallel_group(
group_ranks, get_world_group().local_rank, backend, group_name="fc3_quant_x"
)
def model_parallel_initialized():
return _MC2 is not None