ut: add example and e2e test for sleepmode in external_launcher (#2152)
### What this PR does / why we need it?
This pr add e2e testcase to make sure sleep mode in external_launcher is
ok.
### Does this PR introduce _any_ user-facing change?
not involved
### How was this patch tested?
not involved
- vLLM version: v0.10.0
- vLLM main:
74333ae2f6
Signed-off-by: huangxialu <huangxialu1@huawei.com>
This commit is contained in:
@@ -28,7 +28,7 @@ Single node:
|
|||||||
--proc-per-node=2
|
--proc-per-node=2
|
||||||
MOE models:
|
MOE models:
|
||||||
python examples/offline_external_launcher.py \
|
python examples/offline_external_launcher.py \
|
||||||
--model="Qwen/Qwen3-0.6B" \
|
--model="Qwen/Qwen3-30B-A3B" \
|
||||||
--tp-size=2 \
|
--tp-size=2 \
|
||||||
--proc-per-node=2 \
|
--proc-per-node=2 \
|
||||||
--enable-expert-parallel
|
--enable-expert-parallel
|
||||||
@@ -36,7 +36,7 @@ Single node:
|
|||||||
Multi-node:
|
Multi-node:
|
||||||
Node 0 (assume the node has ip of 10.99.48.128):
|
Node 0 (assume the node has ip of 10.99.48.128):
|
||||||
python examples/offline_external_launcher.py \
|
python examples/offline_external_launcher.py \
|
||||||
--model="Qwen/Qwen3-0.6B" \
|
--model="Qwen/Qwen3-30B-A3B" \
|
||||||
--tp-size=2 \
|
--tp-size=2 \
|
||||||
--node-size=2 \
|
--node-size=2 \
|
||||||
--node-rank=0 \
|
--node-rank=0 \
|
||||||
@@ -46,7 +46,7 @@ Multi-node:
|
|||||||
--master-port=13345
|
--master-port=13345
|
||||||
Node 1:
|
Node 1:
|
||||||
python examples/offline_external_launcher.py \
|
python examples/offline_external_launcher.py \
|
||||||
--model="Qwen/Qwen3-0.6B" \
|
--model="Qwen/Qwen3-30B-A3B" \
|
||||||
--tp-size=2 \
|
--tp-size=2 \
|
||||||
--node-size=2 \
|
--node-size=2 \
|
||||||
--node-rank=1 \
|
--node-rank=1 \
|
||||||
@@ -66,7 +66,7 @@ import torch
|
|||||||
from vllm import LLM, SamplingParams
|
from vllm import LLM, SamplingParams
|
||||||
from vllm.distributed.parallel_state import ( # noqa E402
|
from vllm.distributed.parallel_state import ( # noqa E402
|
||||||
destroy_distributed_environment, destroy_model_parallel, get_tp_group)
|
destroy_distributed_environment, destroy_model_parallel, get_tp_group)
|
||||||
from vllm.utils import get_open_port
|
from vllm.utils import get_open_port, GiB_bytes
|
||||||
|
|
||||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||||
@@ -114,7 +114,28 @@ def parse_args():
|
|||||||
parser.add_argument("--enable-expert-parallel",
|
parser.add_argument("--enable-expert-parallel",
|
||||||
action="store_true",
|
action="store_true",
|
||||||
help="Enable expert parallel, used in MOE models.")
|
help="Enable expert parallel, used in MOE models.")
|
||||||
return parser.parse_args()
|
parser.add_argument("--enable-sleep-mode",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable sleep mode for the engine.")
|
||||||
|
parser.add_argument("--temperature",
|
||||||
|
type=float,
|
||||||
|
default=0.8,
|
||||||
|
help="Float that controls the randomness of the sampling.")
|
||||||
|
parser.add_argument("--model-weight-gib",
|
||||||
|
type=float,
|
||||||
|
default=None,
|
||||||
|
help="Model weight memory usage in GiB (e.g., 1.0 for 0.5B model).")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
if args.enable_sleep_mode:
|
||||||
|
if args.model_weight_gib is None or args.temperature != 0:
|
||||||
|
parser.error("model-weight-gib must be provided, and temperature must be zero when enable-sleep-mode is set.")
|
||||||
|
if args.model_weight_gib <= 0:
|
||||||
|
parser.error("model-weight-gib must be greater than 0 when enable-sleep-mode is set.")
|
||||||
|
if args.model == parser.get_default("model") and args.model_weight_gib is None:
|
||||||
|
parser.error("model-weight-gib must be provided for default model when enable-sleep-mode is set.")
|
||||||
|
|
||||||
|
return args
|
||||||
|
|
||||||
|
|
||||||
def main(
|
def main(
|
||||||
@@ -122,12 +143,15 @@ def main(
|
|||||||
rank: int,
|
rank: int,
|
||||||
master_addr: str,
|
master_addr: str,
|
||||||
master_port: int,
|
master_port: int,
|
||||||
|
model_weight_gib: float,
|
||||||
model: str = "Qwen/Qwen3-0.6B",
|
model: str = "Qwen/Qwen3-0.6B",
|
||||||
world_size: int = 4,
|
world_size: int = 4,
|
||||||
tensor_parallel_size: int = 2,
|
tensor_parallel_size: int = 2,
|
||||||
enable_expert_parallel: bool = False,
|
enable_expert_parallel: bool = False,
|
||||||
enforce_eager: bool = False,
|
enforce_eager: bool = False,
|
||||||
trust_remote_code: bool = True,
|
trust_remote_code: bool = True,
|
||||||
|
enable_sleep_mode: bool = False,
|
||||||
|
temperature: float = 0.8,
|
||||||
):
|
):
|
||||||
os.environ["MASTER_ADDR"] = master_addr
|
os.environ["MASTER_ADDR"] = master_addr
|
||||||
os.environ["MASTER_PORT"] = str(master_port)
|
os.environ["MASTER_PORT"] = str(master_port)
|
||||||
@@ -147,7 +171,7 @@ def main(
|
|||||||
"The future of AI is",
|
"The future of AI is",
|
||||||
] * 10
|
] * 10
|
||||||
sampling_params = SamplingParams(
|
sampling_params = SamplingParams(
|
||||||
temperature=0.8,
|
temperature=temperature,
|
||||||
top_p=0.95,
|
top_p=0.95,
|
||||||
max_tokens=10,
|
max_tokens=10,
|
||||||
)
|
)
|
||||||
@@ -159,10 +183,31 @@ def main(
|
|||||||
trust_remote_code=trust_remote_code,
|
trust_remote_code=trust_remote_code,
|
||||||
distributed_executor_backend="external_launcher",
|
distributed_executor_backend="external_launcher",
|
||||||
seed=0,
|
seed=0,
|
||||||
|
enable_sleep_mode=enable_sleep_mode,
|
||||||
)
|
)
|
||||||
tp_ranks = get_tp_group().ranks
|
tp_ranks = get_tp_group().ranks
|
||||||
print(f'TP RANKS: {tp_ranks}')
|
print(f'TP RANKS: {tp_ranks}')
|
||||||
|
|
||||||
outputs = llm.generate(prompts, sampling_params)
|
outputs = llm.generate(prompts, sampling_params)
|
||||||
|
|
||||||
|
if enable_sleep_mode:
|
||||||
|
if rank == 0:
|
||||||
|
free_bytes_before_sleep, total = torch.npu.mem_get_info()
|
||||||
|
llm.sleep(level=1)
|
||||||
|
if rank == 0:
|
||||||
|
free_bytes_after_sleep, total = torch.npu.mem_get_info()
|
||||||
|
freed_bytes = free_bytes_after_sleep - free_bytes_before_sleep
|
||||||
|
print(f"Freed memory: {freed_bytes / 1024 ** 3:.2f} GiB")
|
||||||
|
# now the freed memory should be larger than the model weights
|
||||||
|
assert freed_bytes >= model_weight_gib / tensor_parallel_size * GiB_bytes
|
||||||
|
|
||||||
|
llm.wake_up()
|
||||||
|
outputs_after_wakeup = llm.generate(prompts, sampling_params)
|
||||||
|
if rank == 0:
|
||||||
|
# cmp output
|
||||||
|
assert outputs[0].outputs[0].text == outputs_after_wakeup[0].outputs[0].text
|
||||||
|
print("Sleep and wake up successfully!!")
|
||||||
|
|
||||||
for i, output in enumerate(outputs):
|
for i, output in enumerate(outputs):
|
||||||
if i >= 5:
|
if i >= 5:
|
||||||
# print only 5 outputs
|
# print only 5 outputs
|
||||||
@@ -214,12 +259,15 @@ if __name__ == "__main__":
|
|||||||
rank,
|
rank,
|
||||||
master_addr,
|
master_addr,
|
||||||
master_port,
|
master_port,
|
||||||
|
args.model_weight_gib,
|
||||||
args.model,
|
args.model,
|
||||||
world_size,
|
world_size,
|
||||||
tp_size,
|
tp_size,
|
||||||
args.enable_expert_parallel,
|
args.enable_expert_parallel,
|
||||||
args.enforce_eager,
|
args.enforce_eager,
|
||||||
args.trust_remote_code,
|
args.trust_remote_code,
|
||||||
|
args.enable_sleep_mode,
|
||||||
|
args.temperature,
|
||||||
))
|
))
|
||||||
|
|
||||||
proc.start()
|
proc.start()
|
||||||
|
|||||||
@@ -24,15 +24,14 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
MODELS = ["Qwen/Qwen3-0.6B"]
|
MODELS = ["Qwen/Qwen3-0.6B"]
|
||||||
|
MOE_MODELS = ["Qwen/Qwen3-30B-A3B"]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("model", MODELS)
|
@pytest.mark.parametrize("model", MODELS)
|
||||||
@patch.dict(os.environ, {"ASCEND_RT_VISIBLE_DEVICES": "0,1,2,3"})
|
|
||||||
def test_external_launcher(model):
|
def test_external_launcher(model):
|
||||||
script = Path(
|
script = Path(
|
||||||
__file__
|
__file__
|
||||||
@@ -71,3 +70,80 @@ def test_external_launcher(model):
|
|||||||
assert "TP RANKS: [1]" in output
|
assert "TP RANKS: [1]" in output
|
||||||
assert "Generated text:" in output
|
assert "Generated text:" in output
|
||||||
assert proc.returncode == 0
|
assert proc.returncode == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model", MOE_MODELS)
|
||||||
|
def test_moe_external_launcher(model):
|
||||||
|
script = Path(
|
||||||
|
__file__
|
||||||
|
).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
|
||||||
|
env = os.environ.copy()
|
||||||
|
# TODO: Change to 2 when ci machine has 4 cards
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
str(script), "--model", model, "--tp-size", "2", "--node-size", "1",
|
||||||
|
"--node-rank", "0", "--proc-per-node", "2", "--trust-remote-code",
|
||||||
|
"--enable-expert-parallel"
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Running subprocess: {' '.join(cmd)}")
|
||||||
|
proc = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
env=env,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
timeout=600,
|
||||||
|
)
|
||||||
|
output = proc.stdout.decode()
|
||||||
|
|
||||||
|
print(output)
|
||||||
|
|
||||||
|
assert "TP RANKS: [0, 1]" in output
|
||||||
|
assert "Generated text:" in output
|
||||||
|
assert proc.returncode == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_external_launcher_and_sleepmode():
|
||||||
|
script = Path(
|
||||||
|
__file__
|
||||||
|
).parent.parent.parent.parent / "examples" / "offline_external_launcher.py"
|
||||||
|
env = os.environ.copy()
|
||||||
|
# TODO: Change to 2 when ci machine has 4 cards
|
||||||
|
cmd = [
|
||||||
|
sys.executable,
|
||||||
|
str(script),
|
||||||
|
"--model",
|
||||||
|
"Qwen/Qwen3-8B",
|
||||||
|
"--tp-size",
|
||||||
|
"1",
|
||||||
|
"--node-size",
|
||||||
|
"1",
|
||||||
|
"--node-rank",
|
||||||
|
"0",
|
||||||
|
"--proc-per-node",
|
||||||
|
"2",
|
||||||
|
"--trust-remote-code",
|
||||||
|
"--enable-sleep-mode",
|
||||||
|
"--temperature",
|
||||||
|
"0",
|
||||||
|
"--model-weight-gib",
|
||||||
|
"16",
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Running subprocess: {' '.join(cmd)}")
|
||||||
|
proc = subprocess.run(
|
||||||
|
cmd,
|
||||||
|
env=env,
|
||||||
|
stdout=subprocess.PIPE,
|
||||||
|
stderr=subprocess.STDOUT,
|
||||||
|
timeout=300,
|
||||||
|
)
|
||||||
|
output = proc.stdout.decode()
|
||||||
|
|
||||||
|
print(output)
|
||||||
|
|
||||||
|
assert "TP RANKS: [0]" in output
|
||||||
|
assert "TP RANKS: [1]" in output
|
||||||
|
assert "Generated text:" in output
|
||||||
|
assert "Sleep and wake up successfully!!" in output
|
||||||
|
assert proc.returncode == 0
|
||||||
|
|||||||
Reference in New Issue
Block a user