[Bugfix] fix dp parallel + tp > 1 offline inference port conflict (#4539)
### What this PR does / why we need it? fix dp parallel + tp > 1 offline inference port conflict issue import PR:https://github.com/vllm-project/vllm-ascend/pull/429 - vLLM version: v0.11.2 - vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.2 --------- Signed-off-by: leo-pony <nengjunma@outlook.com>
This commit is contained in:
1
.github/workflows/_e2e_test.yaml
vendored
1
.github/workflows/_e2e_test.yaml
vendored
@@ -269,6 +269,7 @@ jobs:
|
|||||||
tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
|
tests/e2e/multicard/test_offline_inference_distributed.py::test_models_distributed_DeepSeek_W4A8DYNAMIC
|
||||||
# tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP \
|
# tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_TP2_WITH_EP \
|
||||||
# tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_W8A8_WITH_EP
|
# tests/e2e/multicard/test_qwen3_moe.py::test_models_distributed_Qwen3_MOE_W8A8_WITH_EP
|
||||||
|
pytest -sv tests/e2e/multicard/test_data_parallel_tp2.py
|
||||||
|
|
||||||
- name: Install Ascend toolkit & triton_ascend (for Qwen3-Next-80B-A3B-Instruct)
|
- name: Install Ascend toolkit & triton_ascend (for Qwen3-Next-80B-A3B-Instruct)
|
||||||
shell: bash -l {0}
|
shell: bash -l {0}
|
||||||
|
|||||||
52
tests/e2e/multicard/test_data_parallel_tp2.py
Normal file
52
tests/e2e/multicard/test_data_parallel_tp2.py
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
"""
|
||||||
|
Run `pytest tests/e2e/multicard/test_data_parallel_tp2.py`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
MODELS = ["Qwen/Qwen3-0.6B"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
|
@pytest.mark.parametrize("max_tokens", [32])
|
||||||
|
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3"})
|
||||||
|
def test_data_parallel_inference(model, max_tokens):
|
||||||
|
script = "examples/offline_data_parallel.py"
|
||||||
|
|
||||||
|
env = os.environ.copy()
|
||||||
|
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
script,
|
||||||
|
"--model",
|
||||||
|
model,
|
||||||
|
"--dp-size",
|
||||||
|
"2",
|
||||||
|
"--tp-size",
|
||||||
|
"2",
|
||||||
|
"--node-size",
|
||||||
|
"1",
|
||||||
|
"--node-rank",
|
||||||
|
"0",
|
||||||
|
"--trust-remote-code",
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Running subprocess: {' '.join(cmd)}")
|
||||||
|
proc = subprocess.run(cmd,
|
||||||
|
env=env,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
timeout=600)
|
||||||
|
output = proc.stdout.decode()
|
||||||
|
|
||||||
|
print(output)
|
||||||
|
|
||||||
|
assert "DP rank 0 needs to process" in output
|
||||||
|
assert "DP rank 1 needs to process" in output
|
||||||
|
assert "Generated text:" in output
|
||||||
|
assert proc.returncode == 0
|
||||||
@@ -18,32 +18,10 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import vllm.envs as envs_vllm
|
|
||||||
from vllm.config import ParallelConfig
|
|
||||||
|
|
||||||
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type
|
||||||
|
|
||||||
|
|
||||||
def parallel_config_get_dp_port(self) -> int:
|
|
||||||
"""
|
|
||||||
We might need to initialize process groups in multiple
|
|
||||||
processes that is related to data parallelism,
|
|
||||||
e.g. both in the worker and in the engine, which
|
|
||||||
can live in different processes. To avoid port conflicts, we
|
|
||||||
increment the port number each time we need to initialize a
|
|
||||||
new process group related to data parallelism.
|
|
||||||
"""
|
|
||||||
answer = self.data_parallel_master_port
|
|
||||||
self.data_parallel_master_port += 1
|
|
||||||
|
|
||||||
# NOTE: Get port from envs directly when using torchrun
|
|
||||||
port = envs_vllm.VLLM_DP_MASTER_PORT if envs_vllm.VLLM_DP_MASTER_PORT else answer
|
|
||||||
return port
|
|
||||||
|
|
||||||
|
|
||||||
ParallelConfig.get_next_dp_init_port = parallel_config_get_dp_port
|
|
||||||
|
|
||||||
|
|
||||||
class NullHandle:
|
class NullHandle:
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
|
|||||||
Reference in New Issue
Block a user