### What this PR does / why we need it?
This PR refactors the communication group of MC2 to keep it consistent
with vllm's EP group, making it compatible with PP.
- vLLM version: v0.17.0
- vLLM main: 4034c3d32e
---------
Signed-off-by: QiuChunshuo <qiuchunshuo@huawei.com>
1524 lines · 56 KiB · Python
#
|
||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||
# Copyright 2023 The vLLM team.
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
# This file is a part of the vllm-ascend project.
|
||
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
|
||
#
|
||
|
||
import contextlib
|
||
import copy
|
||
import functools
|
||
import gc
|
||
import json
|
||
import logging
|
||
import multiprocessing
|
||
import os
|
||
import shlex
|
||
import subprocess
|
||
import sys
|
||
import threading
|
||
import time
|
||
import traceback
|
||
from pathlib import Path
|
||
from typing import Any, TypeVar
|
||
|
||
import huggingface_hub
|
||
import numpy as np
|
||
import openai
|
||
import psutil
|
||
import pytest
|
||
import requests
|
||
import torch
|
||
from modelscope import snapshot_download # type: ignore[import-untyped]
|
||
from PIL import Image
|
||
from requests.exceptions import RequestException
|
||
from torch import nn
|
||
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BatchEncoding, BatchFeature
|
||
from transformers.models.auto.auto_factory import _BaseAutoModelClass
|
||
from vllm import LLM, SamplingParams
|
||
from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype
|
||
from vllm.inputs import TextPrompt
|
||
from vllm.outputs import RequestOutput
|
||
from vllm.platforms import current_platform
|
||
from vllm.transformers_utils.utils import maybe_model_redirect
|
||
from vllm.utils.network_utils import get_open_port
|
||
|
||
from tests.e2e.model_utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
|
||
from tests.e2e.nightly.multi_node.scripts.multi_node_config import DisaggregatedPrefillCfg, NodeInfo
|
||
from vllm_ascend.ascend_config import clear_ascend_config
|
||
|
||
# TODO: remove this part once the patches are merged into vLLM. If we do not
# apply the patches explicitly here, some of them may not take effect in the
# pytest scenario.
# NOTE(review): adapt_patch(True) / adapt_patch(False) presumably apply two
# different patch sets (e.g. global-level vs. worker-level) — confirm against
# vllm_ascend.utils.adapt_patch. The vllm.distributed import that follows is
# deliberately placed after these calls (see its E402 suppression), presumably
# so that it picks up the patched modules.
from vllm_ascend.utils import adapt_patch # noqa E402

adapt_patch(True)
adapt_patch(False)
|
||
|
||
from vllm.distributed.parallel_state import ( # noqa E402
|
||
destroy_distributed_environment,
|
||
destroy_model_parallel,
|
||
)
|
||
|
||
# Type variables used by annotations in this module.
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_M = TypeVar("_M")

# Per-prompt multimodal input: a flat list or a list of lists of items
# (presumably one vs. several items per prompt — confirm with callers).
_PromptMultiModalInput = list[_M] | list[list[_M]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
# NOTE(review): the tuple looks like (waveform, sampling_rate) — confirm.
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]

logger = logging.getLogger(__name__)

# Directory containing this file, and the prompt fixtures shipped next to it.
_TEST_DIR = os.path.dirname(__file__)
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "long_prompt.txt")]

# Proxy script used by DisaggEpdProxy, located three directories above this file.
DISAGG_EPD_PROXY_SCRIPT = Path(__file__).parent.parent.parent.joinpath(
    "examples", "disaggregated_encoder", "disagg_epd_proxy.py"
)
|
||
|
||
|
||
def _check_npu_memory_worker(target_free_percentage: float, max_wait_seconds: float):
    """Block until free NPU memory reaches a fraction of total NPU memory.

    Intended to run in a freshly spawned subprocess (``wait_until_npu_memory_free``
    starts it that way) so that the calling process never initializes an NPU
    context itself.

    Args:
        target_free_percentage: Required ratio ``free / total`` (e.g. 0.5).
        max_wait_seconds: Give up once this many seconds have elapsed.

    Returns normally on success. On timeout, prints to stderr and terminates
    the process with exit status 1.
    """
    # We can try to clean up memory in this subprocess, though it mostly affects this process.
    # But if there are any lingering contexts in this process (unlikely for a fresh spawn), it helps.
    gc.collect()
    torch.npu.empty_cache()

    _, total_npu_memory = torch.npu.mem_get_info()
    start_time = time.time()

    while True:
        free_bytes, _ = torch.npu.mem_get_info()
        if free_bytes / total_npu_memory >= target_free_percentage:
            # Free memory has grown to (at least) the target ratio.
            print("check_npu_memory_worker: npu free memory reached target value.")
            return  # Success

        elapsed = time.time() - start_time
        if elapsed > max_wait_seconds:
            # Print to stderr so it's visible in test logs even if captured
            print(
                f"Timeout: NPU memory free size did not reach "
                f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds.",
                file=sys.stderr,
            )
            sys.exit(1)  # Failure

        print(
            f"Waiting for NPU memory to be free: "
            f"{free_bytes / 1024**3:.2f} GB available, "
            f"Elapsed time: {elapsed:.2f} s."
        )
        # Try to clean up, then poll again in one second.
        gc.collect()
        torch.npu.empty_cache()
        time.sleep(1)
|
||
|
||
|
||
def wait_until_npu_memory_free(target_free_percentage: float = 0.5, max_wait_seconds: float = 50):
    """Build a decorator that waits for enough free NPU memory before running a test.

    Before each call, the wrapped function first tears down distributed state
    in this process (``cleanup_dist_env_and_memory``). It then runs
    ``_check_npu_memory_worker`` in a *spawned* subprocess, so the NPU is never
    initialized here.

    Args:
        target_free_percentage (float): Target free memory percentage of total.
        max_wait_seconds (float): Maximum wait time in seconds.

    Raises:
        TimeoutError: the checker subprocess exited with a non-zero status.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Clean up non-NPU resources in the main process
            cleanup_dist_env_and_memory()

            # Use a spawned subprocess to check NPU memory to avoid initializing NPU in the main process
            spawn_ctx = multiprocessing.get_context("spawn")
            checker = spawn_ctx.Process(
                target=_check_npu_memory_worker,
                args=(target_free_percentage, max_wait_seconds),
            )
            checker.start()
            checker.join()

            if checker.exitcode == 0:
                return func(*args, **kwargs)
            raise TimeoutError(
                f"Timeout: NPU memory free size did not reach "
                f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds."
            )

        return wrapper

    return decorator
|
||
|
||
|
||
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down vLLM / torch.distributed state and release cached memory.

    Steps, in order: destroy vLLM model-parallel groups, destroy the vLLM
    distributed environment, destroy the default torch process group
    (ignoring the AssertionError raised when none exists), optionally shut
    down Ray, run the garbage collector, and finally empty the NPU cache and
    reset its peak statistics when an NPU context already exists.

    Args:
        shutdown_ray: Also call ``ray.shutdown()`` (Ray is imported lazily).
    """
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()

    if shutdown_ray:
        import ray  # Lazy import Ray

        ray.shutdown()

    gc.collect()

    # Only clean NPU cache if NPU is already initialized/available in this process.
    # This prevents accidental initialization of NPU context in the main process,
    # which would break subsequent forks.
    npu_context_exists = hasattr(torch, "npu") and torch.npu.is_initialized()
    if not npu_context_exists:
        return
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()
|
||
|
||
|
||
class MooncakeLauncher:
    """Context manager that runs a ``mooncake_master`` process while the block executes.

    ``__enter__`` starts the process; ``__exit__`` terminates it (SIGTERM, then
    SIGKILL after 5 seconds) and reaps it.
    """

    def __init__(
        self,
        mooncake_port,
        mooncake_metrics_port,
        eviction_high_watermark_ratio=0.8,
        eviction_ratio=0.05,
    ):
        self.mooncake_port = mooncake_port
        self.mooncake_metrics_port = mooncake_metrics_port
        self.eviction_high_watermark_ratio = eviction_high_watermark_ratio
        self.eviction_ratio = eviction_ratio
        # Set by __enter__. Starting as None makes __exit__ safe even if the
        # launcher was never entered.
        self.process: subprocess.Popen | None = None

    def __enter__(self):
        cmd = [
            "mooncake_master",
            "--eviction_high_watermark_ratio",
            str(self.eviction_high_watermark_ratio),
            "--eviction_ratio",
            str(self.eviction_ratio),
            "--port",
            str(self.mooncake_port),
            "--metrics_port",
            str(self.mooncake_metrics_port),
        ]

        logger.info("Launching mooncake: %s", " ".join(cmd))
        curr_ld_path = os.environ.get("LD_LIBRARY_PATH", "")
        mooncake_ld_path = "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/mooncake"
        # NOTE: this updates os.environ itself, not only the child's env, so
        # processes launched later from this process inherit the path too.
        # Empty entries are skipped: a trailing ':' would put the current
        # directory on the library search path.
        os.environ["LD_LIBRARY_PATH"] = os.pathsep.join(p for p in (mooncake_ld_path, curr_ld_path) if p)
        env = os.environ.copy()
        self.process = subprocess.Popen(cmd, env=env)
        return self

    def __exit__(self, exc_type, exc, tb):
        if self.process is None:
            return
        logger.info("Stopping mooncake server...")
        self.process.terminate()
        try:
            self.process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            self.process.kill()
            # Reap the killed process so it does not linger as a zombie.
            self.process.wait()
        self.process = None
|
||
|
||
|
||
class RemoteOpenAIServer:
    """Run ``vllm serve`` in a subprocess and block until it is healthy.

    Besides the single-server case, this supports multi-node tests
    (``nodes_info``) and disaggregated-prefill deployments. In those, readiness
    is checked first on a proxy on the first node, then on every non-headless
    node. Use it as a context manager so the server is terminated on exit.
    """

    DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key

    def _start_server(self, model: str, server_cmd: list[str], env_dict: dict[str, str] | None) -> None:
        """Subclasses override this method to customize server process launch.

        Runs ``server_cmd`` with the current environment plus ``env_dict``. The
        child's stdout/stderr go to this process's stdout/stderr.
        """
        env = os.environ.copy()
        # the current process might initialize npu,
        # to be safe, we should use spawn method
        env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
        if env_dict is not None:
            env.update(env_dict)
        logger.info("Starting server with command: %s", " ".join(server_cmd))
        self.proc: subprocess.Popen = subprocess.Popen(
            server_cmd,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )

    def __init__(
        self,
        model: str,
        vllm_serve_args: list[str] | str,
        *,
        server_host: str = "0.0.0.0",
        server_port: int = 8080,
        env_dict: dict[str, str] | None = None,
        seed: int | None = None,
        auto_port: bool = True,
        nodes_info: list[NodeInfo] | None = None,
        disaggregated_prefill: DisaggregatedPrefillCfg | None = None,
        proxy_port: int | None = None,
        max_wait_seconds: float | None = None,
        override_hf_configs: dict[str, Any] | None = None,
    ) -> None:
        """Start the server and wait until it is ready.

        Args:
            model: Model name inserted after ``vllm serve`` when ``vllm_serve_args``
                is a list.
            vllm_serve_args: Either extra CLI args (list), appended after
                ``vllm serve <model>``, or a complete command line (str), which is
                shell-split and used as-is (``model`` is then not inserted).
            server_host: Host used for health-check and client URLs.
            server_port: Port used for URLs when ``auto_port`` is False or when
                ``--uds`` is passed.
            env_dict: Extra environment variables for the server process.
            seed: If set, appended as ``--seed``.
            auto_port: Pick a free port and append ``--port``. URLs then use
                that port instead of ``server_port``.
            nodes_info: Cluster description for multi-node tests.
            disaggregated_prefill: If truthy, wait for the proxy and all
                non-headless nodes instead of only this server.
            proxy_port: Proxy port; required with ``disaggregated_prefill``.
            max_wait_seconds: Readiness timeout; a falsy value means 2800 s.
            override_hf_configs: Passed as JSON through ``--hf-overrides``.

        Raises:
            ValueError: ``-p``/``--port`` combined with ``auto_port=True``, or
                ``--seed`` combined with ``seed``.
            AssertionError: ``disaggregated_prefill`` without ``proxy_port``.
            RuntimeError: the server exited unexpectedly or was not ready in time.
        """
        if isinstance(vllm_serve_args, str):
            vllm_serve_args = shlex.split(vllm_serve_args)
        else:
            vllm_serve_args = ["vllm", "serve", model, *vllm_serve_args]
        if auto_port:
            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
                raise ValueError("You have manually specified the port when `auto_port=True`.")

            # No need for a port if using unix sockets
            if "--uds" not in vllm_serve_args:
                # The server listens on this port, so health checks and url_for()
                # must use it as well (previously they used ``server_port``).
                server_port = get_open_port()
                # Don't mutate the input args
                vllm_serve_args = vllm_serve_args + ["--port", str(server_port)]
        if seed is not None:
            if "--seed" in vllm_serve_args:
                raise ValueError(f"You have manually specified the seed when `seed={seed}`.")

            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]

        if override_hf_configs is not None:
            vllm_serve_args = vllm_serve_args + ["--hf-overrides", json.dumps(override_hf_configs)]

        # Validate before launching, so a misconfigured test cannot leak a server process.
        if disaggregated_prefill:
            assert proxy_port is not None, "for disaggregated_prefill, proxy port must be provided"

        self.host = str(server_host)
        self.port = int(server_port)
        # for multi-nodes test
        self.nodes_info = nodes_info
        self.disaggregated_prefill = disaggregated_prefill
        # NOTE: a str when LWS_WORKER_INDEX is set, the int 0 otherwise.
        self.cur_index = os.getenv("LWS_WORKER_INDEX", 0)
        self.proxy_port = proxy_port

        self._start_server(model, vllm_serve_args, env_dict)
        max_wait_seconds = max_wait_seconds or 2800
        if self.disaggregated_prefill:
            self._wait_for_server_pd(timeout=max_wait_seconds)
        else:
            self._wait_for_multiple_servers([(self.host, self.url_for("health"))], timeout=max_wait_seconds)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._terminate_server()

    def _poll(self) -> int | None:
        """Subclasses override this method to customize process polling"""
        return self.proc.poll()

    def hang_until_terminated(self, url) -> None:
        """
        Wait until the server process terminates.
        This is for headless mode, where the api server
        process only exists in the leader node.

        Polls ``url`` every 5 s. It stops at the first non-200 response or request
        error, and then terminates the local server process.
        """
        logger.info("Hanging until server process terminates...")
        client = requests
        try:
            while True:
                try:
                    resp = client.get(url, timeout=5)
                    if resp.status_code != 200:
                        break
                    time.sleep(5)
                except Exception:
                    break
        finally:
            self._terminate_server()

    def _wait_for_server_pd(self, timeout: float):
        """Wait for the proxy on the first node, then for every non-headless node's API server."""
        # Wait for all api_server nodes ready
        assert self.nodes_info is not None, "cluster info must be provided"
        proxy_port = self.proxy_port

        def url_health(ip: str, port: int) -> str:
            return f"http://{ip}:{port}/health"

        targets = [
            (node_info.ip, url_health(node_info.ip, self.port))
            for node_info in self.nodes_info
            if not node_info.headless
        ]

        # Wait for proxy ready
        master_node = self.nodes_info[0]
        url_proxy = f"http://{master_node.ip}:{proxy_port}/healthcheck"

        # Wait for master node proxy first
        self._wait_for_multiple_servers([(master_node.ip, url_proxy)], timeout=timeout)

        # Then wait for all api_server nodes
        self._wait_for_multiple_servers(targets=targets, timeout=timeout)

    def _wait_for_multiple_servers(
        self, targets, timeout: float, log_interval: float = 30.0, always_check_nodes: bool = False
    ):
        """Poll health URLs until every target answers HTTP 200.

        Args:
            targets: Iterable of ``(node_ip, url)`` pairs. Several pairs may share
                one ``node_ip``; each pair is tracked separately.
            timeout: Seconds to wait before terminating the server and raising.
            log_interval: Minimum seconds between "[WAIT]" debug messages.
            always_check_nodes: Re-probe targets that were already ready in every
                round, so all targets must be healthy in the same round to finish.

        Raises:
            RuntimeError: ``self._poll()`` reports a non-zero exit code, or the
                timeout elapses (the server is terminated first).
        """
        start = time.time()
        # Materialize so a generator is not exhausted after the first round.
        targets = list(targets)
        # Keyed on the full (node_ip, url) pair. Keying on node_ip alone made
        # several endpoints on one host look ready once any one of them was.
        ready = {(node_ip, url): False for node_ip, url in targets}
        last_log_time = 0.0

        while True:
            now = time.time()
            all_ready = True
            should_log = (now - last_log_time) >= log_interval

            for node_ip, url in targets:
                key = (node_ip, url)
                if ready[key] and not always_check_nodes:
                    continue

                try:
                    # Per-request timeout, so one hung endpoint cannot stall the
                    # loop far beyond the overall ``timeout``.
                    resp = requests.get(url, timeout=10)
                except RequestException:
                    all_ready = False
                    ready[key] = False
                    if should_log:
                        logger.debug("[WAIT] %s: connection failed", url)

                    # check unexpected exit
                    result = self._poll()
                    if result is not None and result != 0:
                        raise RuntimeError(f"Server at {node_ip} exited unexpectedly.") from None
                    continue

                if resp.status_code == 200:
                    if not ready[key]:
                        logger.info("[READY] Node %s: %s is ready.", node_ip, url)
                    ready[key] = True
                else:
                    # Any non-200 answer (e.g. 503 while loading) is not ready yet.
                    all_ready = False
                    ready[key] = False
                    if should_log:
                        logger.debug("[WAIT] %s: status %s", url, resp.status_code)

            if should_log:
                last_log_time = now

            if all_ready:
                break

            if now - start > timeout:
                not_ready_nodes = [node_ip for (node_ip, _url), ok in ready.items() if not ok]
                self._terminate_server()
                raise RuntimeError(
                    f"Timeout: these nodes did not become ready: {not_ready_nodes} in time: {timeout}s"
                ) from None

            time.sleep(5)

    @property
    def url_root(self) -> str:
        return f"http://{self.host}:{self.port}"

    def _terminate_server(self) -> None:
        """Subclasses override this method to customize server process termination"""
        self.proc.terminate()
        try:
            self.proc.wait(8)
        except subprocess.TimeoutExpired:
            # force kill if needed
            self.proc.kill()
            # Reap the killed process; bounded, in case it is stuck in the kernel.
            with contextlib.suppress(subprocess.TimeoutExpired):
                self.proc.wait(8)

    def url_for(self, *parts: str) -> str:
        return self.url_root + "/" + "/".join(parts)

    def get_client(self, **kwargs):
        """Return a synchronous OpenAI client for this server (600 s default timeout, no retries)."""
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return openai.OpenAI(
            base_url=self.url_for("v1"),
            api_key=self.DUMMY_API_KEY,
            max_retries=0,
            **kwargs,
        )

    def get_async_client(self, **kwargs):
        """Return an asynchronous OpenAI client for this server (600 s default timeout, no retries)."""
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return openai.AsyncOpenAI(base_url=self.url_for("v1"), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs)
|
||
|
||
|
||
class RemoteEPDServer(RemoteOpenAIServer):
    """Launch one or more local ``vllm serve`` instances (e.g. for encoder/prefill/decode tests).

    ``vllm_serve_args`` is either one argument list (a single instance) or a
    list of argument lists (one instance each). Instance ``i`` gets
    ``ASCEND_RT_VISIBLE_DEVICES=i``, and its output is echoed with a ``[VLLM_i]``
    prefix. Every instance must pass ``--port`` explicitly.
    """

    def _start_server(self, model: str, server_cmd: list[str], env_dict: dict[str, str] | None) -> None:
        """Subclasses override this method to customize server process launch"""
        raise NotImplementedError("RemoteEPDServer should use _start_server_with_prefix instead")

    def __init__(
        self,
        vllm_serve_args: list[str] | list[list[str]],
        server_host: str = "0.0.0.0",
        env_dict: dict[str, str] | None = None,
        max_wait_seconds: float | None = 2800,
    ) -> None:
        """Start every instance and wait until all health endpoints return 200.

        Raises:
            RuntimeError: ``vllm_serve_args`` is not a list, an instance exited
                with an error, or the instances were not ready in time.
            ValueError: an instance's args lack a usable ``--port``.

        Instances that were already started are terminated before an exception
        propagates, because ``__exit__`` never runs when ``__init__`` raises.
        """
        self._proc_list = []

        self.env_dict: dict[str, str] = {}
        if env_dict is not None:
            self.env_dict.update(env_dict)

        self.env_dict["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
        self.env_dict["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True"
        self.env_dict["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        self.vllm_serve_args_list = []
        self.health_url_list = []
        self.host = server_host

        if not isinstance(vllm_serve_args, list):
            raise RuntimeError("vllm_serve_args must be a list")
        if all(isinstance(item, list) for item in vllm_serve_args):
            # One argument list per instance; deep-copied so the caller's lists are untouched.
            self.vllm_serve_args_list = [[str(arg) for arg in sublist] for sublist in copy.deepcopy(vllm_serve_args)]
        else:
            # A flat argument list describes a single instance.
            self.vllm_serve_args_list.append([str(arg) for arg in copy.deepcopy(vllm_serve_args)])

        serve_arg_cmd = ["vllm", "serve"]
        try:
            for i, serve_args in enumerate(self.vllm_serve_args_list):
                # The env dict is copied at launch (see _start_server_with_prefix),
                # so each instance keeps its own device id.
                self.env_dict["ASCEND_RT_VISIBLE_DEVICES"] = str(i)
                self.port = self._parse_port(serve_args)
                self.health_url_list.append(super().url_for("health"))
                proc = self._start_server_with_prefix([*serve_arg_cmd, *serve_args], self.env_dict, f"[VLLM_{i}] ")
                self._proc_list.append(proc)

            timeout_value = float(max_wait_seconds) if max_wait_seconds is not None else 2800.0
            super()._wait_for_multiple_servers(
                [(self.host, url) for url in self.health_url_list], timeout=timeout_value, always_check_nodes=True
            )
        except BaseException:
            self._terminate_server()
            raise

    @staticmethod
    def _parse_port(args: list[str]) -> int:
        """Return the integer that follows ``--port`` in ``args``; raise ValueError if absent."""
        if "--port" not in args:
            raise ValueError(f"--port must be specified explicitly in args: {args}")
        index = args.index("--port")
        if index + 1 >= len(args):
            raise ValueError(f"--port is missing its value in args: {args}")
        return int(args[index + 1])

    def _poll(self) -> int | None:
        """Return the exit code of the first instance that exited with an error, else None."""
        for proc in self._proc_list:
            code = proc.poll()
            if code is not None and code != 0:
                return code
        return None

    def _delete_shm(self) -> None:
        """Remove the EC shared-storage path configured for any instance.

        For each instance whose args contain ``--ec-transfer-config``, the JSON
        value is parsed. If ``ec_connector_extra_config.shared_storage_path`` is
        set, ``rm -r -f`` is run on it, and this method waits for it to finish.
        """
        for arg in self.vllm_serve_args_list:
            if "--ec-transfer-config" not in arg:
                continue
            config_str = arg[arg.index("--ec-transfer-config") + 1]
            config_dict = json.loads(config_str)
            ec_connector_extra_config = config_dict.get("ec_connector_extra_config", {})
            shm_path = ec_connector_extra_config.get("shared_storage_path")
            if shm_path:
                args = ["rm", "-r", "-f", str(shm_path)]
                print(f"delete shm_path is: {shm_path}")
                proc = self._start_server_with_prefix(args, None, "[DELETE] ")
                # Wait so the path is actually gone (and the child reaped) on return.
                proc.wait()

    def _read_output(self, pipe, prefix):
        """Echo every line read from ``pipe`` to stdout, prefixed with ``prefix``."""
        try:
            with pipe:
                for line in iter(pipe.readline, ""):
                    if line:
                        print(f"{prefix}: {line}", end="")

        except Exception as e:
            print(f"error: {e}")
            traceback.print_exc()

    def _start_server_with_prefix(self, server_cmd: list[str], env_dict: dict[str, str] | None, log_prefix: str):
        """Start ``server_cmd`` (os.environ + ``env_dict``) and relay its stdout/stderr through daemon threads."""
        env = os.environ.copy()
        if env_dict is not None:
            env.update(env_dict)
        proc = subprocess.Popen(
            server_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, bufsize=1
        )
        stdout_thread = threading.Thread(target=self._read_output, args=(proc.stdout, log_prefix), daemon=True)
        stderr_thread = threading.Thread(target=self._read_output, args=(proc.stderr, log_prefix), daemon=True)

        stdout_thread.start()
        stderr_thread.start()
        return proc

    def _terminate_server(self) -> None:
        """kill process and its children

        For each launched process: SIGTERM its descendants (SIGKILL after 10 s),
        then the process itself, and finally reap it. Processes that already
        exited are skipped, which also makes repeated calls safe.
        """
        print("vllm instance is stopping")
        for proc in self._proc_list:
            if proc.poll() is not None:
                # Already exited (now reaped). Any orphans it left were re-parented
                # and cannot be found through it any more.
                continue
            try:
                parent = psutil.Process(proc.pid)
                children = parent.children(recursive=True)
            except psutil.NoSuchProcess:
                continue
            for child in children:
                with contextlib.suppress(psutil.NoSuchProcess):
                    child.terminate()

            _gone, still_alive = psutil.wait_procs(children, timeout=10)

            for child in still_alive:
                with contextlib.suppress(psutil.NoSuchProcess):
                    child.kill()

            try:
                parent.terminate()
                parent.wait(timeout=10)
            except (psutil.NoSuchProcess, psutil.TimeoutExpired):
                with contextlib.suppress(psutil.NoSuchProcess):
                    parent.kill()
            # Make sure the Popen object is reaped and not left as a zombie.
            with contextlib.suppress(subprocess.TimeoutExpired):
                proc.wait(timeout=10)

    def __enter__(self):
        """Context manager entry point."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit point - clean up all processes."""
        self._terminate_server()
|
||
|
||
|
||
class DisaggEpdProxy(RemoteEPDServer):
    """Run the ``disagg_epd_proxy.py`` example script and wait until its ``/health`` returns 200.

    The proxy's output is echoed with a ``[PROXY]`` prefix. ``--port`` must be
    present in ``proxy_args``.
    """

    def __init__(
        self,
        proxy_args: list[str] | str | None = None,
        env_dict: dict[str, str] | None = None,
        server_host: str = "0.0.0.0",
        max_wait_seconds: float | None = 2800,
    ) -> None:
        """Validate the arguments, launch the proxy and wait for it to be healthy.

        Args:
            proxy_args: Proxy CLI args as a list, or a string that is shell-split.
            env_dict: Extra environment variables for the proxy process.
            server_host: Host used for the health-check URL.
            max_wait_seconds: Readiness timeout (None means 2800 s).

        Raises:
            ValueError: ``--port`` is missing or has no value (checked before launching).
            RuntimeError: the proxy was not ready in time.
        """
        if proxy_args is None:
            proxy_args_list: list[str] = []
        elif isinstance(proxy_args, str):
            proxy_args_list = shlex.split(proxy_args)
        else:
            # Copy so later changes to the caller's list do not affect us.
            proxy_args_list = list(proxy_args)

        self.proxy_args = proxy_args_list
        self.env_dict: dict[str, str] = {}
        if env_dict is not None:
            self.env_dict.update(env_dict)
        self._proc_list = list()
        self.host = server_host

        # Validate the port before launching, so a bad config cannot leak a proxy process.
        if "--port" not in self.proxy_args:
            raise ValueError("--port must be specified in proxy args")
        index = self.proxy_args.index("--port")
        if index + 1 >= len(self.proxy_args):
            raise ValueError("--port is missing its value in proxy args")
        self.port = int(self.proxy_args[index + 1])

        print(f"proxy param is: {self.proxy_args}")
        proxy_cmd = ["python", str(DISAGG_EPD_PROXY_SCRIPT), *self.proxy_args]
        proc = self._start_server_with_prefix(proxy_cmd, self.env_dict, "[PROXY] ")
        self._proc_list.append(proc)

        timeout_value = float(max_wait_seconds) if max_wait_seconds is not None else 2800.0
        super()._wait_for_multiple_servers([(self.host, super().url_for("health"))], timeout=timeout_value)

    def __enter__(self):
        """Context manager entry point."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit point - clean up all processes."""
        super()._terminate_server()
|
||
|
||
|
||
_DP_RUNNER_START_TIMEOUT_SECONDS = 900.0
|
||
_DP_RUNNER_REQUEST_TIMEOUT_SECONDS = 900.0
|
||
_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS = 30.0
|
||
|
||
|
||
def _split_data_parallel_indices(num_items: int, dp_size: int) -> list[list[int]]:
|
||
if num_items < 0:
|
||
raise ValueError("num_items must be non-negative")
|
||
if dp_size <= 0:
|
||
raise ValueError("dp_size must be positive")
|
||
|
||
floor = num_items // dp_size
|
||
remainder = num_items % dp_size
|
||
|
||
def start(rank: int) -> int:
|
||
return rank * floor + min(rank, remainder)
|
||
|
||
return [list(range(start(rank), start(rank + 1))) for rank in range(dp_size)]
|
||
|
||
|
||
def _slice_optional_inputs(inputs: PromptImageInput | PromptAudioInput | PromptVideoInput | None, indices: list[int]):
    """Return ``[inputs[i] for i in indices]``, or ``None`` when ``inputs`` is ``None``."""
    return None if inputs is None else list(map(inputs.__getitem__, indices))
|
||
|
||
|
||
def _slice_list_inputs(items: list[Any], indices: list[int]) -> list[Any]:
|
||
return [items[index] for index in indices]
|
||
|
||
|
||
def _merge_data_parallel_results(total_items: int, shard_results: list[tuple[list[int], list[Any]]]) -> list[Any]:
|
||
merged: list[Any] = [None] * total_items
|
||
for indices, results in shard_results:
|
||
if not indices:
|
||
continue
|
||
if len(indices) != len(results):
|
||
raise RuntimeError("Mismatched result count returned by data parallel worker")
|
||
for index, result in zip(indices, results):
|
||
merged[index] = result
|
||
|
||
if any(result is None for result in merged):
|
||
raise RuntimeError("Some data parallel results were not returned")
|
||
|
||
return merged
|
||
|
||
|
||
def _normalize_score_inputs(text_1: str | list[str], text_2: str | list[str]) -> tuple[list[str], list[str]]:
|
||
if isinstance(text_1, str) and isinstance(text_2, str):
|
||
return [text_1], [text_2]
|
||
if isinstance(text_1, str):
|
||
return [text_1] * len(text_2), list(text_2)
|
||
if isinstance(text_2, str):
|
||
return list(text_1), [text_2] * len(text_1)
|
||
if len(text_1) != len(text_2):
|
||
raise ValueError("`text_1` and `text_2` must have the same length")
|
||
return list(text_1), list(text_2)
|
||
|
||
|
||
def _run_vllm_runner_dp_worker(conn, llm_kwargs: dict[str, Any], dp_rank: int, dp_size: int, master_port: int) -> None:
|
||
llm = None
|
||
try:
|
||
os.environ["VLLM_DP_RANK"] = str(dp_rank)
|
||
os.environ["VLLM_DP_RANK_LOCAL"] = str(dp_rank)
|
||
os.environ["VLLM_DP_SIZE"] = str(dp_size)
|
||
os.environ["VLLM_DP_MASTER_IP"] = "127.0.0.1"
|
||
os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
|
||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||
|
||
llm = LLM(**llm_kwargs)
|
||
conn.send({"status": "ready", "rank": dp_rank})
|
||
|
||
while True:
|
||
request = conn.recv()
|
||
command = request["command"]
|
||
if command == "shutdown":
|
||
break
|
||
|
||
result: Any
|
||
if command == "generate":
|
||
req_outputs = llm.generate(
|
||
request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"]
|
||
)
|
||
result = VllmRunner._finalize_generate_outputs(req_outputs)
|
||
elif command == "generate_w_logprobs":
|
||
req_outputs = llm.generate(
|
||
request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"]
|
||
)
|
||
result = VllmRunner._final_steps_generate_w_logprobs(req_outputs)
|
||
elif command == "classify":
|
||
req_outputs = llm.classify(request["prompts"])
|
||
result = [req_output.outputs.probs for req_output in req_outputs]
|
||
elif command == "embed":
|
||
req_outputs = llm.embed(request["inputs"], *request["args"], **request["kwargs"])
|
||
result = [req_output.outputs.embedding for req_output in req_outputs]
|
||
elif command == "encode":
|
||
req_outputs = llm.encode(request["prompts"])
|
||
result = [req_output.outputs.data for req_output in req_outputs]
|
||
elif command == "reward":
|
||
req_outputs = llm.reward(request["prompts"])
|
||
result = [req_output.outputs.data for req_output in req_outputs]
|
||
elif command == "score":
|
||
req_outputs = llm.score(request["text_1"], request["text_2"], *request["args"], **request["kwargs"])
|
||
result = [req_output.outputs.score for req_output in req_outputs]
|
||
else:
|
||
raise ValueError(f"Unsupported data parallel command: {command}")
|
||
|
||
conn.send({"status": "ok", "rank": dp_rank, "indices": request["indices"], "result": result})
|
||
except Exception:
|
||
with contextlib.suppress(Exception):
|
||
conn.send({"status": "error", "rank": dp_rank, "traceback": traceback.format_exc()})
|
||
raise
|
||
finally:
|
||
if llm is not None:
|
||
del llm
|
||
clear_ascend_config()
|
||
cleanup_dist_env_and_memory()
|
||
with contextlib.suppress(Exception):
|
||
conn.close()
|
||
|
||
|
||
class VllmRunner:
|
||
def __init__(
    self,
    model_name: str,
    runner: RunnerOption = "auto",
    convert: ConvertOption = "auto",
    tokenizer_name: str | None = None,
    tokenizer_mode: str = "auto",
    max_model_len: int | None = 1024,
    dtype: str = "auto",
    disable_log_stats: bool = True,
    tensor_parallel_size: int = 1,
    block_size: int = 16,
    enable_chunked_prefill: bool = True,
    swap_space: int = 4,
    enforce_eager: bool | None = False,
    quantization: str | None = None,
    **kwargs,
) -> None:
    """Create an in-process ``vllm.LLM`` for ``model_name`` and store it as ``self.model``.

    ``trust_remote_code`` is always True. Extra ``kwargs`` are forwarded to
    ``LLM`` unchanged. Repeating one of the explicitly passed options in
    ``kwargs`` raises ``TypeError`` (duplicate keyword argument).

    Raises:
        ValueError: ``data_parallel_size`` in ``kwargs`` is greater than 1.
    """
    if int(kwargs.get("data_parallel_size", 1)) > 1:
        raise ValueError("VllmRunner does not support `data_parallel_size > 1`; use `DPVllmRunner` instead.")

    explicit_options = {
        "model": model_name,
        "runner": runner,
        "convert": convert,
        "tokenizer": tokenizer_name,
        "tokenizer_mode": tokenizer_mode,
        "trust_remote_code": True,
        "dtype": dtype,
        "swap_space": swap_space,
        "enforce_eager": enforce_eager,
        "disable_log_stats": disable_log_stats,
        "tensor_parallel_size": tensor_parallel_size,
        "max_model_len": max_model_len,
        "block_size": block_size,
        "enable_chunked_prefill": enable_chunked_prefill,
        "quantization": quantization,
    }
    # Two separate ** unpackings keep the "duplicate keyword -> TypeError" behavior.
    self.model = LLM(**explicit_options, **kwargs)
|
||
|
||
@staticmethod
def _finalize_generate_outputs(req_outputs: list[RequestOutput]) -> list[tuple[list[list[int]], list[str]]]:
    """Convert vLLM request outputs into ``(token_ids_per_sample, texts_per_sample)`` pairs.

    For every request, each sample's token ids are the prompt ids followed by
    the generated ids, and each sample's text is the prompt text followed by
    the generated text. A ``None`` prompt text or ``None`` prompt token ids are
    treated as empty. Previously a ``None`` ``prompt_token_ids`` raised TypeError.
    NOTE(review): presumably this happens for prompt-embeds inputs — confirm.
    """
    outputs: list[tuple[list[list[int]], list[str]]] = []
    for req_output in req_outputs:
        prompt_str = req_output.prompt or ""
        prompt_ids = list(req_output.prompt_token_ids or [])
        req_sample_output_ids: list[list[int]] = []
        req_sample_output_strs: list[str] = []
        for sample in req_output.outputs:
            req_sample_output_ids.append(prompt_ids + list(sample.token_ids))
            req_sample_output_strs.append(prompt_str + sample.text)
        outputs.append((req_sample_output_ids, req_sample_output_strs))
    return outputs
|
||
|
||
def get_inputs(
    self,
    prompts: list[str] | list[torch.Tensor] | list[list[int]],
    images: PromptImageInput | None = None,
    videos: PromptVideoInput | None = None,
    audios: PromptAudioInput | None = None,
) -> list[TextPrompt]:
    """Build one ``TextPrompt`` per prompt, attaching any per-prompt multimodal data.

    Each prompt is handled by type:
      * ``str`` is stored as ``prompt``;
      * ``list`` (token ids) is stored as ``prompt_token_ids``;
      * anything else is stored as ``prompt_embeds`` (presumably a tensor).
    The annotation now says ``list[list[int]]``, matching the token-id branch
    and ``generate``'s signature; it previously said ``list[int]``.

    ``images`` / ``videos`` / ``audios`` are parallel to ``prompts``. A ``None``
    entry means that prompt has no data of that modality.

    Raises:
        ValueError: a non-None modality list has a different length than ``prompts``.
    """
    if any(x is not None and len(x) != len(prompts) for x in [images, videos, audios]):
        raise ValueError("All non-None multimodal inputs must have the same length as prompts")

    inputs: list[TextPrompt] = []
    for i, prompt in enumerate(prompts):
        multi_modal_data = {}
        if images is not None and (image := images[i]) is not None:
            multi_modal_data["image"] = image
        if videos is not None and (video := videos[i]) is not None:
            multi_modal_data["video"] = video  # type: ignore
        if audios is not None and (audio := audios[i]) is not None:
            multi_modal_data["audio"] = audio  # type: ignore

        # An empty dict becomes None, i.e. "no multimodal data".
        text_prompt_kwargs: dict[str, Any] = {"multi_modal_data": multi_modal_data or None}
        if isinstance(prompt, str):
            text_prompt_kwargs["prompt"] = prompt
        elif isinstance(prompt, list):
            text_prompt_kwargs["prompt_token_ids"] = prompt
        else:
            text_prompt_kwargs["prompt_embeds"] = prompt

        inputs.append(TextPrompt(**text_prompt_kwargs))

    return inputs
|
||
|
||
def generate(
    self,
    prompts: list[str] | list[torch.Tensor] | list[list[int]],
    sampling_params: SamplingParams,
    images: PromptImageInput | None = None,
    videos: PromptVideoInput | None = None,
    audios: PromptAudioInput | None = None,
    **kwargs: Any,
) -> list[tuple[list[list[int]], list[str]]]:
    """Generate with ``sampling_params``; return ``(token_ids, texts)`` per request.

    Both the token ids and the texts of each sample include the prompt (see
    ``_finalize_generate_outputs``). Extra ``kwargs`` go to ``LLM.generate``.
    """
    return self._finalize_generate_outputs(
        self.model.generate(
            self.get_inputs(prompts, images=images, videos=videos, audios=audios),
            sampling_params=sampling_params,
            **kwargs,
        )
    )
|
||
|
||
@staticmethod
def _final_steps_generate_w_logprobs(
    req_outputs: list[RequestOutput],
) -> list[TokensTextLogprobsPromptLogprobs]:
    """Flatten request outputs into one tuple per sample.

    Each tuple is ``(generated_token_ids, generated_text, sample_logprobs,
    request_prompt_logprobs)``. The prompt is not included in the ids or the text.
    Raises AssertionError if a request has no samples.
    """
    flattened: list[TokensTextLogprobsPromptLogprobs] = []
    for request in req_outputs:
        assert len(request.outputs) > 0
        flattened.extend(
            (list(sample.token_ids), sample.text, sample.logprobs, request.prompt_logprobs)
            for sample in request.outputs
        )
    return flattened
|
||
|
||
def generate_w_logprobs(
    self,
    prompts: list[str],
    sampling_params: SamplingParams,
    images: PromptImageInput | None = None,
    audios: PromptAudioInput | None = None,
    videos: PromptVideoInput | None = None,
    **kwargs: Any,
) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
    """Generate and return per-sample ``(ids, text, logprobs[, prompt_logprobs])`` tuples.

    The trailing prompt-logprobs element is dropped when
    ``sampling_params.prompt_logprobs`` is None.
    """
    inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
    req_outputs = self.model.generate(inputs, sampling_params=sampling_params, **kwargs)
    per_sample = self._final_steps_generate_w_logprobs(req_outputs)

    if sampling_params.prompt_logprobs is not None:
        return per_sample
    # Omit prompt logprobs if not required by sampling params
    return [entry[:-1] for entry in per_sample]
|
||
|
||
def generate_greedy(
    self,
    prompts: list[str] | list[torch.Tensor] | list[list[int]],
    max_tokens: int,
    images: PromptImageInput | None = None,
    videos: PromptVideoInput | None = None,
    audios: PromptAudioInput | None = None,
    **kwargs: Any,
) -> list[tuple[list[int], str]]:
    """Greedily decode up to *max_tokens* tokens per prompt.

    Uses ``temperature=0.0`` and keeps only the first completion of each
    prompt, returned as ``(token_ids, text)``. Extra keyword arguments are
    forwarded to ``self.generate``.
    """
    greedy_outputs = self.generate(
        prompts,
        SamplingParams(temperature=0.0, max_tokens=max_tokens),
        images=images,
        videos=videos,
        audios=audios,
        **kwargs,
    )
    return [(token_id_lists[0], texts[0]) for token_id_lists, texts in greedy_outputs]
|
||
|
||
def generate_greedy_logprobs(
    self,
    prompts: list[str],
    max_tokens: int,
    num_logprobs: int | None,
    num_prompt_logprobs: int | None = None,
    images: PromptImageInput | None = None,
    audios: PromptAudioInput | None = None,
    videos: PromptVideoInput | None = None,
    stop_token_ids: list[int] | None = None,
    stop: list[str] | None = None,
    **kwargs: Any,
) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
    """Greedily decode *prompts* and return tokens, text and logprobs.

    Args:
        max_tokens: Maximum number of generated tokens per prompt.
        num_logprobs: Number of logprobs to return per generated token.
        num_prompt_logprobs: Number of logprobs per prompt token; when ``None``
            the prompt-logprobs element is omitted from each result tuple.
        stop_token_ids / stop: Optional stop conditions for ``SamplingParams``.
        **kwargs: Forwarded to ``self.generate_w_logprobs``.
    """
    sampling_kwargs = dict(
        temperature=0.0,
        max_tokens=max_tokens,
        logprobs=num_logprobs,
        prompt_logprobs=num_prompt_logprobs,
        stop_token_ids=stop_token_ids,
        stop=stop,
    )
    return self.generate_w_logprobs(
        prompts,
        SamplingParams(**sampling_kwargs),
        images=images,
        audios=audios,
        videos=videos,
        **kwargs,
    )
|
||
|
||
def classify(self, prompts: list[str]) -> list[list[float]]:
    """Return the class-probability vector ``outputs.probs`` of ``LLM.classify`` per prompt."""
    request_outputs = self.model.classify(prompts)
    probabilities = [result.outputs.probs for result in request_outputs]
    return probabilities
|
||
|
||
def embed(
    self,
    prompts: list[str],
    images: PromptImageInput | None = None,
    videos: PromptVideoInput | None = None,
    audios: PromptAudioInput | None = None,
    *args,
    **kwargs,
) -> list[list[float]]:
    """Return one embedding (``outputs.embedding``) per prompt via ``LLM.embed``.

    Prompts and optional multimodal data are converted by ``self.get_inputs``;
    remaining positional/keyword arguments are forwarded to ``self.model.embed``.
    """
    llm_inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
    request_outputs = self.model.embed(llm_inputs, *args, **kwargs)
    embeddings = [result.outputs.embedding for result in request_outputs]
    return embeddings
|
||
|
||
def encode(self, prompts: list[str]) -> list[list[float]]:
    """Return the raw pooled output ``outputs.data`` of ``LLM.encode`` for each prompt."""
    request_outputs = self.model.encode(prompts)
    pooled = [result.outputs.data for result in request_outputs]
    return pooled
|
||
|
||
def reward(self, prompts: list[str]) -> list[list[float]]:
    """Return the reward output ``outputs.data`` of ``LLM.reward`` for each prompt."""
    request_outputs = self.model.reward(prompts)
    rewards = [result.outputs.data for result in request_outputs]
    return rewards
|
||
|
||
def score(
|
||
self,
|
||
text_1: str | list[str],
|
||
text_2: str | list[str],
|
||
*args,
|
||
**kwargs,
|
||
) -> list[float]:
|
||
req_outputs = self.model.score(text_1, text_2, *args, **kwargs)
|
||
return [req_output.outputs.score for req_output in req_outputs]
|
||
|
||
def __enter__(self):
    """Return the runner itself so it can be used in a ``with`` statement."""
    return self
|
||
|
||
def __exit__(self, exc_type, exc_value, traceback):
    """Release the runner's model and global state when leaving the ``with`` block.

    Drops the ``self.model`` reference first, then calls
    ``clear_ascend_config()`` and ``cleanup_dist_env_and_memory()``.
    Returns ``None``, so any exception raised inside the block propagates.
    """
    # Drop the model reference before the global cleanup helpers run
    # (presumably so its memory can actually be reclaimed — confirm).
    del self.model
    clear_ascend_config()
    cleanup_dist_env_and_memory()
|
||
|
||
|
||
class DPVllmRunner(VllmRunner):
    """``VllmRunner`` variant that spreads requests over data-parallel worker processes.

    Every data-parallel rank runs ``_run_vllm_runner_dp_worker`` in its own
    ``spawn``-ed process; this object keeps only one pipe per rank. Each public
    call:

    1. splits its inputs into per-rank shards;
    2. sends one command message per rank;
    3. waits for every reply;
    4. merges the shard results back into input order.

    If any worker fails or times out, all workers are stopped and the error is
    re-raised. Later calls then raise ``RuntimeError`` instead of dispatching
    to no workers.
    """

    def __init__(
        self,
        model_name: str,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer_name: str | None = None,
        tokenizer_mode: str = "auto",
        max_model_len: int | None = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = True,
        swap_space: int = 4,
        enforce_eager: bool | None = False,
        quantization: str | None = None,
        data_parallel_size: int = 2,
        **kwargs,
    ) -> None:
        """Start ``data_parallel_size`` worker processes.

        ``VllmRunner.__init__`` is deliberately not called: no model lives in
        this process (see the ``model`` property). All engine arguments are
        collected into one kwargs dict that is passed to every worker.

        Two extra keyword arguments are consumed here and NOT forwarded:

        - ``dp_start_timeout``: seconds to wait for a worker to report ready.
        - ``dp_request_timeout``: seconds to wait for a worker to answer a command.

        Raises:
            ValueError: if ``data_parallel_size < 2``.
            TimeoutError, RuntimeError: if a worker fails to start.
        """
        if data_parallel_size < 2:
            raise ValueError("DPVllmRunner requires `data_parallel_size >= 2`")

        self._dp_size = data_parallel_size
        self._dp_parent_conns: list[Any] = []
        self._dp_processes: list[Any] = []
        self._dp_start_timeout = float(kwargs.pop("dp_start_timeout", _DP_RUNNER_START_TIMEOUT_SECONDS))
        self._dp_request_timeout = float(kwargs.pop("dp_request_timeout", _DP_RUNNER_REQUEST_TIMEOUT_SECONDS))

        llm_kwargs = dict(
            model=model_name,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            quantization=quantization,
            **kwargs,
        )

        cleanup_dist_env_and_memory()
        self._start_data_parallel_workers(llm_kwargs)

    @property
    def model(self) -> LLM:
        """Unsupported: the ``LLM`` instances live only inside the worker processes."""
        raise RuntimeError("Direct access to `runner.model` is not supported by `DPVllmRunner`.")

    def _start_data_parallel_workers(self, llm_kwargs: dict[str, Any]) -> None:
        """Spawn one worker per DP rank and wait until each reports ``"ready"``.

        All workers share one freshly picked ``master_port``. On any failure the
        workers started so far are stopped before the exception propagates.
        """
        ctx = multiprocessing.get_context("spawn")
        master_port = get_open_port()

        try:
            for dp_rank in range(self._dp_size):
                parent_conn, child_conn = ctx.Pipe()
                proc = ctx.Process(
                    target=_run_vllm_runner_dp_worker,
                    args=(child_conn, llm_kwargs, dp_rank, self._dp_size, master_port),
                )
                proc.start()
                # The child owns its end now; closing ours means the parent sees EOF if the child dies.
                child_conn.close()
                self._dp_parent_conns.append(parent_conn)
                self._dp_processes.append(proc)

            for rank, conn in enumerate(self._dp_parent_conns):
                if not conn.poll(self._dp_start_timeout):
                    raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to start")
                message = conn.recv()
                if message["status"] != "ready":
                    raise RuntimeError(
                        f"Failed to start data parallel worker {rank}:\n{message.get('traceback', 'unknown error')}"
                    )
        except Exception:
            self._stop_data_parallel_workers()
            raise

    def _stop_data_parallel_workers(self) -> None:
        """Best-effort shutdown of all workers; safe to call more than once.

        Sends ``shutdown`` to every worker and joins each process. A process
        still alive after ``_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS`` is killed.
        Finally all pipes are closed and the bookkeeping lists are cleared.
        """
        for conn in self._dp_parent_conns:
            with contextlib.suppress(Exception):
                conn.send({"command": "shutdown"})

        for proc in self._dp_processes:
            proc.join(timeout=_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS)
            if proc.is_alive():
                proc.kill()
                proc.join(timeout=5)

        for conn in self._dp_parent_conns:
            with contextlib.suppress(Exception):
                conn.close()

        self._dp_parent_conns.clear()
        self._dp_processes.clear()

    def _run_sharded_command(self, command: str, num_items: int, build_payload) -> list[Any]:
        """Run *command* on every worker over its shard and merge the results.

        Args:
            command: Command name handled by ``_run_vllm_runner_dp_worker``.
            num_items: Number of input items to split across the ranks.
            build_payload: Called once per rank with that rank's non-empty list
                of item indices. Returns the command-specific message fields
                for that rank.

        Returns:
            Per-item results merged back into input order by
            ``_merge_data_parallel_results``.

        Raises:
            RuntimeError: if the workers are not running or a worker reports a failure.
            TimeoutError: if a worker does not answer within ``dp_request_timeout``.
        """
        if not self._dp_parent_conns:
            raise RuntimeError(f"Cannot run `{command}`: data parallel workers are not running")

        shard_indices = _split_data_parallel_indices(num_items, self._dp_size)

        # Build every message before sending any of them. If preparing inputs
        # fails (e.g. mismatched multimodal list lengths), no worker has been
        # given a command, so the workers stay idle and the pipes stay in sync.
        messages: list[dict[str, Any]] = []
        for rank in range(len(self._dp_parent_conns)):
            indices = shard_indices[rank]
            # A rank with no assigned items still gets item 0 as a placeholder
            # but reports empty ``indices``. Presumably every DP rank must take
            # part in each step; confirm against ``_run_vllm_runner_dp_worker``.
            messages.append({"command": command, "indices": indices, **build_payload(indices or [0])})

        shard_results: list[tuple[list[int], list[Any]]] = []
        try:
            for conn, message in zip(self._dp_parent_conns, messages):
                conn.send(message)

            for rank, conn in enumerate(self._dp_parent_conns):
                if not conn.poll(self._dp_request_timeout):
                    raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `{command}`")
                reply = conn.recv()
                if reply["status"] != "ok":
                    raise RuntimeError(
                        f"Data parallel worker {rank} failed during `{command}`:\n"
                        f"{reply.get('traceback', 'unknown error')}"
                    )
                shard_results.append((reply["indices"], reply["result"]))
        except Exception:
            # A partially completed request can leave replies unread or missing
            # on the pipes. Tear the workers down instead of reusing them in an
            # unknown state.
            self._stop_data_parallel_workers()
            raise

        return _merge_data_parallel_results(num_items, shard_results)

    def _dispatch_prompt_command(
        self,
        command: str,
        prompts: list[str] | list[torch.Tensor] | list[list[int]],
        *,
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        **payload: Any,
    ) -> list[Any]:
        """Shard *prompts* and their multimodal data and run *command* on the workers.

        Each worker receives engine ``inputs`` built locally by
        ``self.get_inputs``, the raw ``prompts`` of its shard and *payload*.
        """
        if not prompts:
            return []

        def build_payload(worker_indices: list[int]) -> dict[str, Any]:
            worker_prompts = _slice_list_inputs(prompts, worker_indices)
            return {
                "inputs": self.get_inputs(
                    worker_prompts,
                    images=_slice_optional_inputs(images, worker_indices),
                    videos=_slice_optional_inputs(videos, worker_indices),
                    audios=_slice_optional_inputs(audios, worker_indices),
                ),
                "prompts": worker_prompts,
                **payload,
            }

        return self._run_sharded_command(command, len(prompts), build_payload)

    def _dispatch_text_command(self, command: str, prompts: list[str]) -> list[Any]:
        """Shard plain text *prompts* and run *command* on the workers."""
        if not prompts:
            return []

        def build_payload(worker_indices: list[int]) -> dict[str, Any]:
            return {"prompts": _slice_list_inputs(prompts, worker_indices)}

        return self._run_sharded_command(command, len(prompts), build_payload)

    def generate(
        self,
        prompts: list[str] | list[torch.Tensor] | list[list[int]],
        sampling_params: SamplingParams,
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        """Data-parallel ``VllmRunner.generate``; extra kwargs are shipped to the workers."""
        return self._dispatch_prompt_command(
            "generate",
            prompts,
            images=images,
            videos=videos,
            audios=audios,
            sampling_params=sampling_params,
            kwargs=kwargs,
        )

    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: PromptImageInput | None = None,
        audios: PromptAudioInput | None = None,
        videos: PromptVideoInput | None = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
        """Data-parallel ``VllmRunner.generate_w_logprobs``.

        The trailing prompt-logprobs element is dropped here when
        ``sampling_params.prompt_logprobs`` is ``None``.
        """
        # NOTE(review): if the worker's `generate_w_logprobs` handler already
        # drops prompt logprobs, this second strip would also remove the
        # per-token logprobs. Verify against `_run_vllm_runner_dp_worker`.
        toks_str_logsprobs_prompt_logprobs = self._dispatch_prompt_command(
            "generate_w_logprobs",
            prompts,
            images=images,
            videos=videos,
            audios=audios,
            sampling_params=sampling_params,
            kwargs=kwargs,
        )
        return (
            [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
            if sampling_params.prompt_logprobs is None
            else toks_str_logsprobs_prompt_logprobs
        )

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Data-parallel ``VllmRunner.classify``."""
        return self._dispatch_text_command("classify", prompts)

    def embed(
        self,
        prompts: list[str],
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        *args,
        **kwargs,
    ) -> list[list[float]]:
        """Data-parallel ``VllmRunner.embed``; extra args/kwargs are shipped to the workers."""
        return self._dispatch_prompt_command(
            "embed",
            prompts,
            images=images,
            videos=videos,
            audios=audios,
            args=args,
            kwargs=kwargs,
        )

    def encode(self, prompts: list[str]) -> list[list[float]]:
        """Data-parallel ``VllmRunner.encode``."""
        return self._dispatch_text_command("encode", prompts)

    def reward(self, prompts: list[str]) -> list[list[float]]:
        """Data-parallel ``VllmRunner.reward``."""
        return self._dispatch_text_command("reward", prompts)

    def score(
        self,
        text_1: str | list[str],
        text_2: str | list[str],
        *args,
        **kwargs,
    ) -> list[float]:
        """Data-parallel ``VllmRunner.score``.

        Inputs are first normalized by ``_normalize_score_inputs`` into two
        parallel lists, which are then sharded together.
        """
        normalized_text_1, normalized_text_2 = _normalize_score_inputs(text_1, text_2)
        if not normalized_text_1:
            return []

        def build_payload(worker_indices: list[int]) -> dict[str, Any]:
            return {
                "text_1": _slice_list_inputs(normalized_text_1, worker_indices),
                "text_2": _slice_list_inputs(normalized_text_2, worker_indices),
                "args": args,
                "kwargs": kwargs,
            }

        return self._run_sharded_command("score", len(normalized_text_1), build_payload)

    def __exit__(self, exc_type, exc_value, traceback):
        """Stop all workers, then clear the Ascend config and distributed state."""
        self._stop_data_parallel_workers()
        clear_ascend_config()
        cleanup_dist_env_and_memory()
|
||
|
||
|
||
# Longer-named alias: both names refer to the same class object.
DataParallelVllmRunner = DPVllmRunner
|
||
|
||
|
||
class HfRunner:
    """Reference runner that executes a model with Hugging Face libraries.

    Depending on the constructor flags, ``self.model`` is one of:

    - a ``sentence_transformers.SentenceTransformer``;
    - a ``sentence_transformers.CrossEncoder``;
    - a ``transformers`` auto-model created by ``auto_cls``.
    """

    def get_default_device(self):
        """Return ``"cpu"`` on CPU platforms, otherwise the platform's device type."""
        return "cpu" if current_platform.is_cpu() else current_platform.device_type

    def wrap_device(self, x: _T, device: str | None = None) -> _T:
        """Move *x* to *device* (default ``self.device``), recursing into dicts.

        ``None`` and booleans are returned unchanged, as is any object whose
        ``device.type`` already equals *device*. Everything else must provide
        a ``.to(device)`` method.
        """
        if x is None or isinstance(x, (bool,)):
            return x

        if device is None:
            device = self.device

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: dict[str, Any] | None = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        """Load config, model, tokenizer and processor for *model_name*.

        Args:
            model_name: Model id; passed through ``maybe_model_redirect`` first.
            dtype: Requested dtype, resolved by ``_get_and_verify_dtype``.
            model_kwargs: Extra kwargs for the model constructor. The dict is
                copied, so the caller's mapping is never modified.
                ``torch_dtype`` defaults to the resolved dtype.
            trust_remote_code: Forwarded to every ``from_pretrained`` call.
            is_sentence_transformer: Load with ``SentenceTransformer``.
            is_cross_encoder: Load with ``CrossEncoder``.
            skip_tokenizer_init: Take the tokenizer from the processor instead
                of loading it separately.
            auto_cls: ``transformers`` auto class used for plain models.
        """
        model_name = maybe_model_redirect(model_name)
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        # Copy before adding the ``torch_dtype`` default so the caller's dict
        # is never mutated (it may be reused across runners).
        model_kwargs = dict(model_kwargs) if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if getattr(model, "quantization_method", None) is None and any(
                p.dtype != self.dtype for p in model.parameters()
            ):
                model = model.to(dtype=self.dtype)

            # Only move the model when it is not bitsandbytes-quantized and its
            # parameters are not already spread over several devices.
            if (
                getattr(model, "quantization_method", None) != "bitsandbytes"
                and len({p.device for p in model.parameters()}) < 2
            ):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

    def get_inputs(
        self,
        prompts: list[str],
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
    ) -> list[BatchFeature | BatchEncoding]:
        """Run ``self.processor`` on each prompt with its optional media.

        Each media list, when given, must be as long as *prompts*; ``None``
        entries inside it are skipped. An audio entry of length 2 is unpacked
        as ``(audio, sampling_rate)``. ``BatchFeature`` results are cast to
        ``self.dtype``.
        """
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        all_inputs: list[BatchFeature | BatchEncoding] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and (image := images[i]) is not None:
                processor_kwargs["images"] = image
            if videos is not None and (video := videos[i]) is not None:
                processor_kwargs["videos"] = video
            if audios is not None and (audio_inputs := audios[i]) is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
                inputs = inputs.to(dtype=self.dtype)

            all_inputs.append(inputs)

        return all_inputs

    def classify(self, prompts: list[str]) -> list[list[float]]:
        """Return one post-processed logit vector per prompt.

        The post-processing follows ``self.config.problem_type``:

        - ``"regression"``: raw logits;
        - ``"multi_label_classification"``: sigmoid of the logits;
        - anything else: softmax of the logits.

        Only the first row of the logits is kept for each prompt.
        """
        # output is final logits
        all_inputs = self.get_inputs(prompts)
        outputs = []
        problem_type = getattr(self.config, "problem_type", "")

        for inputs in all_inputs:
            output = self.model(**self.wrap_device(inputs))
            if problem_type == "regression":
                logits = output.logits[0].tolist()
            elif problem_type == "multi_label_classification":
                logits = output.logits.sigmoid()[0].tolist()
            else:
                logits = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(logits)

        return outputs

    def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]:
        """Delegate to ``self.model.encode``; meant for sentence-transformer models."""
        return self.model.encode(prompts, *args, **kwargs)

    def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
        """Delegate to ``self.model.predict`` with ``convert_to_tensor=True``; meant for cross-encoders."""
        return self.model.predict(prompts, *args, convert_to_tensor=True, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup_dist_env_and_memory()
|
||
|
||
|
||
@pytest.fixture(scope="session")
def ilama_lora_files():
    """Session-wide local path of the ``vllm-ascend/ilama-text2sql-spider`` LoRA.

    Fetched with the module-level (ModelScope) ``snapshot_download``. When
    ``HF_HUB_OFFLINE`` is set, only files that are already cached locally
    are used.
    """
    offline_only = huggingface_hub.constants.HF_HUB_OFFLINE
    return snapshot_download(repo_id="vllm-ascend/ilama-text2sql-spider", local_files_only=offline_only)
|
||
|
||
|
||
@pytest.fixture(scope="session")
def llama32_lora_files():
    """Session-wide local path of the ``jeeejeee/llama32-3b-text2sql-spider`` LoRA.

    Resolved from the local Hugging Face cache only (``local_files_only=True``),
    so the snapshot must already be present.
    """
    # Aliased to distinguish it from the module-level ModelScope ``snapshot_download``.
    from huggingface_hub import snapshot_download as hf_snapshot_download

    lora_repo = "jeeejeee/llama32-3b-text2sql-spider"
    return hf_snapshot_download(repo_id=lora_repo, local_files_only=True)
|
||
|
||
|
||
def qwen_prompt(questions: list[str]) -> list[str]:
    """Wrap each question in the Qwen chat template with a single image placeholder."""
    image_token = "<|image_pad|>"
    prompts: list[str] = []
    for question in questions:
        prompts.append(
            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            f"<|im_start|>user\n<|vision_start|>{image_token}<|vision_end|>"
            f"{question}<|im_end|>\n<|im_start|>assistant\n"
        )
    return prompts
|
||
|
||
|
||
def hunyuan_prompt(questions: list[str]) -> list[str]:
    """Build HunyuanOCR prompts: begin token, image placeholder tokens, question, user token."""
    placeholder = "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>"  # noqa: E501
    prompts: list[str] = []
    for question in questions:
        prompts.append(f"<|hy_begin▁of▁sentence|>{placeholder}{question}<|hy_User|>")
    return prompts
|
||
|
||
|
||
# Per-family vision-language settings, one ``vl_config`` fixture param per key.
# Each entry names the model, the function that turns a list of questions into
# prompts, and the ``mm_processor_kwargs`` to use for that model. An optional
# "skip" key makes ``vl_config`` skip the test with that value as the reason.
PROMPT_CONFIGS = {
    "qwen-vl": dict(
        model="Qwen/Qwen3-VL-8B-Instruct",
        prompt_fn=qwen_prompt,
        mm_processor_kwargs=dict(
            min_pixels=28 * 28,
            max_pixels=1280 * 28 * 28,
            fps=1,
        ),
    ),
    "hunyuan-vl": dict(
        model="Tencent-Hunyuan/HunyuanOCR",
        prompt_fn=hunyuan_prompt,
        mm_processor_kwargs={},
    ),
}
|
||
|
||
|
||
@pytest.fixture(params=PROMPT_CONFIGS.keys())
def vl_config(request):
    """Yield one ``PROMPT_CONFIGS`` entry per parametrized model family.

    An entry that carries a ``"skip"`` key skips the test, using that value as
    the reason.
    """
    selected = PROMPT_CONFIGS[request.param]
    if "skip" not in selected:
        return selected
    pytest.skip(selected["skip"])
    return selected
|