#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import contextlib
import copy
import functools
import gc
import json
import logging
import multiprocessing
import os
import shlex
import subprocess
import sys
import threading
import time
import traceback
from pathlib import Path
from typing import Any, TypeVar
import huggingface_hub
import numpy as np
import openai
import psutil
import pytest
import requests
import torch
from modelscope import snapshot_download  # type: ignore[import-untyped]
from PIL import Image
from requests.exceptions import RequestException
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, BatchEncoding, BatchFeature
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm import LLM, SamplingParams
from vllm.config.model import ConvertOption, RunnerOption, _get_and_verify_dtype
from vllm.inputs import TextPrompt
from vllm.outputs import RequestOutput
from vllm.platforms import current_platform
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils.network_utils import get_open_port

from tests.e2e.model_utils import TokensTextLogprobs, TokensTextLogprobsPromptLogprobs
from tests.e2e.nightly.multi_node.scripts.multi_node_config import DisaggregatedPrefillCfg, NodeInfo

from vllm_ascend.ascend_config import clear_ascend_config

# TODO: remove this part after the patch is merged into vllm; if we do not
# explicitly patch here, some of the patches might not take effect in the
# pytest scenario.
from vllm_ascend.utils import adapt_patch  # noqa E402

adapt_patch(True)
adapt_patch(False)

from vllm.distributed.parallel_state import (  # noqa E402
    destroy_distributed_environment,
    destroy_model_parallel,
)

_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_M = TypeVar("_M")

_PromptMultiModalInput = list[_M] | list[list[_M]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
PromptAudioInput = _PromptMultiModalInput[tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]

logger = logging.getLogger(__name__)

_TEST_DIR = os.path.dirname(__file__)
_LONG_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "long_prompt.txt")]

DISAGG_EPD_PROXY_SCRIPT = (
    Path(__file__).parent.parent.parent / "examples" / "disaggregated_encoder" / "disagg_epd_proxy.py"
)


def _check_npu_memory_worker(target_free_percentage: float, max_wait_seconds: float):
    # We can try to clean up memory in this subprocess, though it mostly affects this process.
    # But if there are any lingering contexts in this process (unlikely for a fresh spawn), it helps.
    gc.collect()
    torch.npu.empty_cache()

    _, total_npu_memory = torch.npu.mem_get_info()
    start_time = time.time()

    while True:
        free_bytes, _ = torch.npu.mem_get_info()
        if free_bytes / total_npu_memory >= target_free_percentage:
            print("check_npu_memory_worker: npu free memory reached the target value.")
            return  # Success

        elapsed = time.time() - start_time
        if elapsed > max_wait_seconds:
            # Print to stderr so it's visible in test logs even if captured
            print(
                f"Timeout: NPU memory free size did not reach "
                f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds.",
                file=sys.stderr,
            )
            sys.exit(1)  # Failure

        print(
            f"Waiting for NPU memory to be free: "
            f"{free_bytes / 1024**3:.2f} GB available, "
            f"Elapsed time: {elapsed:.2f} s."
        )
        # Try to clean up
        gc.collect()
        torch.npu.empty_cache()
        time.sleep(1)


def wait_until_npu_memory_free(target_free_percentage: float = 0.5, max_wait_seconds: float = 50):
    """Decorator to wait until the NPU memory free size is above target_free_percentage.

    Args:
        target_free_percentage (float): Target free memory percentage of total.
        max_wait_seconds (float): Maximum wait time in seconds.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            # Clean up non-NPU resources in the main process
            cleanup_dist_env_and_memory()

            # Use a spawned subprocess to check NPU memory to avoid initializing NPU in the main process
            ctx = multiprocessing.get_context("spawn")
            p = ctx.Process(target=_check_npu_memory_worker, args=(target_free_percentage, max_wait_seconds))
            p.start()
            p.join()

            if p.exitcode != 0:
                raise TimeoutError(
                    f"Timeout: NPU memory free size did not reach "
                    f"{target_free_percentage} of total npu memory within {max_wait_seconds} seconds."
                )

            return func(*args, **kwargs)

        return wrapper

    return decorator
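

# Illustrative usage of the decorator above (a sketch; the test name is
# hypothetical and not part of this module):
#
#   @wait_until_npu_memory_free(target_free_percentage=0.7, max_wait_seconds=120)
#   def test_large_model_e2e():
#       ...  # runs only after at least 70% of the NPU memory is free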


def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    if shutdown_ray:
        import ray  # Lazy import Ray

        ray.shutdown()
    gc.collect()
    # Only clean NPU cache if NPU is already initialized/available in this process.
    # This prevents accidental initialization of NPU context in the main process,
    # which would break subsequent forks.
    if hasattr(torch, "npu") and torch.npu.is_initialized():
        torch.npu.empty_cache()
        torch.npu.reset_peak_memory_stats()


class MooncakeLauncher:
    """Context manager that launches and tears down a `mooncake_master` process."""

    def __init__(
        self,
        mooncake_port,
        mooncake_metrics_port,
        eviction_high_watermark_ratio=0.8,
        eviction_ratio=0.05,
    ):
        self.mooncake_port = mooncake_port
        self.mooncake_metrics_port = mooncake_metrics_port
        self.eviction_high_watermark_ratio = eviction_high_watermark_ratio
        self.eviction_ratio = eviction_ratio

    def __enter__(self):
        cmd = [
            "mooncake_master",
            "--eviction_high_watermark_ratio",
            str(self.eviction_high_watermark_ratio),
            "--eviction_ratio",
            str(self.eviction_ratio),
            "--port",
            str(self.mooncake_port),
            "--metrics_port",
            str(self.mooncake_metrics_port),
        ]

        logger.info("Launching mooncake: %s", " ".join(cmd))
        curr_ld_path = os.environ.get("LD_LIBRARY_PATH", "")
        mooncake_ld_path = "/usr/local/Ascend/ascend-toolkit/latest/python/site-packages/mooncake:"
        os.environ["LD_LIBRARY_PATH"] = mooncake_ld_path + curr_ld_path
        env = os.environ.copy()
        self.process = subprocess.Popen(cmd, env=env)
        return self

    def __exit__(self, exc_type, exc, tb):
        if not self.process:
            return
        logger.info("Stopping mooncake server...")
        self.process.terminate()
        try:
            self.process.wait(timeout=5)
        except subprocess.TimeoutExpired:
            self.process.kill()
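

# Illustrative usage of MooncakeLauncher (a sketch; the port numbers below are
# placeholders, not values used by the test suite):
#
#   with MooncakeLauncher(mooncake_port=50051, mooncake_metrics_port=50052):
#       ...  # run tests that need a running mooncake_master instance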


class RemoteOpenAIServer:
    DUMMY_API_KEY = "token-abc123"  # vLLM's OpenAI server does not need API key

    def _start_server(self, model: str, server_cmd: list[str], env_dict: dict[str, str] | None) -> None:
        """Subclasses override this method to customize server process launch"""
        env = os.environ.copy()
        # the current process might initialize npu,
        # to be safe, we should use spawn method
        env["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
        if env_dict is not None:
            env.update(env_dict)
        logger.info(f"Starting server with command: {' '.join(server_cmd)}")
        self.proc: subprocess.Popen = subprocess.Popen(
            server_cmd,
            env=env,
            stdout=sys.stdout,
            stderr=sys.stderr,
        )

    def __init__(
        self,
        model: str,
        vllm_serve_args: list[str] | str,
        *,
        server_host: str = "0.0.0.0",
        server_port: int = 8080,
        env_dict: dict[str, str] | None = None,
        seed: int | None = None,
        auto_port: bool = True,
        nodes_info: list[NodeInfo] | None = None,
        disaggregated_prefill: DisaggregatedPrefillCfg | None = None,
        proxy_port: int | None = None,
        max_wait_seconds: float | None = None,
        override_hf_configs: dict[str, Any] | None = None,
    ) -> None:
        if isinstance(vllm_serve_args, str):
            vllm_serve_args = shlex.split(vllm_serve_args)
        else:
            vllm_serve_args = ["vllm", "serve", model, *vllm_serve_args]
        if auto_port:
            if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
                raise ValueError("You have manually specified the port when `auto_port=True`.")

            # No need for a port if using unix sockets
            if "--uds" not in vllm_serve_args:
                # Don't mutate the input args
                vllm_serve_args = vllm_serve_args + ["--port", str(get_open_port())]
        if seed is not None:
            if "--seed" in vllm_serve_args:
                raise ValueError(f"You have manually specified the seed when `seed={seed}`.")

            vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]

        if override_hf_configs is not None:
            vllm_serve_args = vllm_serve_args + ["--hf-overrides", json.dumps(override_hf_configs)]

        self.host = str(server_host)
        self.port = int(server_port)
        # for multi-node tests
        self.nodes_info = nodes_info
        self.disaggregated_prefill = disaggregated_prefill
        self.cur_index = os.getenv("LWS_WORKER_INDEX", 0)
        self.proxy_port = proxy_port

        self._start_server(model, vllm_serve_args, env_dict)
        max_wait_seconds = max_wait_seconds or 2800
        if self.disaggregated_prefill:
            assert proxy_port is not None, "for disaggregated_prefill, proxy port must be provided"
            self._wait_for_server_pd(timeout=max_wait_seconds)
        else:
            self._wait_for_multiple_servers([(self.host, self.url_for("health"))], timeout=max_wait_seconds)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self._terminate_server()

    def _poll(self) -> int | None:
        """Subclasses override this method to customize process polling"""
        return self.proc.poll()

    def hang_until_terminated(self, url) -> None:
        """
        Wait until the server process terminates.
        This is for headless mode, where the api server
        process only exists in the leader node.
        """
        logger.info("Hanging until server process terminates...")
        client = requests
        try:
            while True:
                try:
                    resp = client.get(url, timeout=5)
                    if resp.status_code != 200:
                        break
                    time.sleep(5)
                except Exception:
                    break
        finally:
            self._terminate_server()

    def _wait_for_server_pd(self, timeout: float):
        # Wait for all api_server nodes ready
        assert self.nodes_info is not None, "cluster info must be provided"
        proxy_port = self.proxy_port

        def url_health(ip: str, port: int) -> str:
            return f"http://{ip}:{port}/health"

        targets = [
            (node_info.ip, url_health(node_info.ip, self.port))
            for node_info in self.nodes_info
            if not node_info.headless
        ]

        # Wait for proxy ready
        master_node = self.nodes_info[0]
        url_proxy = f"http://{master_node.ip}:{proxy_port}/healthcheck"

        # Wait for master node proxy first
        self._wait_for_multiple_servers([(master_node.ip, url_proxy)], timeout=timeout)

        # Then wait for all api_server nodes
        self._wait_for_multiple_servers(targets=targets, timeout=timeout)

    def _wait_for_multiple_servers(
        self, targets, timeout: float, log_interval: float = 30.0, always_check_nodes: bool = False
    ):
        """
        targets: list of (node_ip, url) pairs to poll for readiness.
        log_interval: minimum number of seconds between progress log messages.
        """
        start = time.time()
        client = requests

        ready = {node_ip: False for node_ip, _ in targets}

        last_log_time = 0.0

        while True:
            now = time.time()
            all_ready = True
            should_log = (now - last_log_time) >= log_interval

            for node_ip, url in targets:
                if ready[node_ip] and not always_check_nodes:
                    continue

                try:
                    resp = client.get(url)
                    if resp.status_code == 200:
                        ready[node_ip] = True
                        logger.info(f"[READY] Node {node_ip}: {url} is ready.")
                except RequestException:
                    all_ready = False
                    if should_log:
                        logger.debug(f"[WAIT] {url}: connection failed")

                # check unexpected exit
                result = self._poll()
                if result is not None and result != 0:
                    raise RuntimeError(f"Server at {node_ip} exited unexpectedly.") from None

            if should_log:
                last_log_time = now

            if all_ready:
                break

            if now - start > timeout:
                not_ready_nodes = [n for n, ok in ready.items() if not ok]
                self._terminate_server()
                raise RuntimeError(
                    f"Timeout: these nodes did not become ready: {not_ready_nodes} in time: {timeout}s"
                ) from None

            time.sleep(5)

    @property
    def url_root(self) -> str:
        return f"http://{self.host}:{self.port}"

    def _terminate_server(self) -> None:
        """Subclasses override this method to customize server process termination"""
        self.proc.terminate()
        try:
            self.proc.wait(8)
        except subprocess.TimeoutExpired:
            # force kill if needed
            self.proc.kill()

    def url_for(self, *parts: str) -> str:
        return self.url_root + "/" + "/".join(parts)

    def get_client(self, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return openai.OpenAI(
            base_url=self.url_for("v1"),
            api_key=self.DUMMY_API_KEY,
            max_retries=0,
            **kwargs,
        )

    def get_async_client(self, **kwargs):
        if "timeout" not in kwargs:
            kwargs["timeout"] = 600
        return openai.AsyncOpenAI(base_url=self.url_for("v1"), api_key=self.DUMMY_API_KEY, max_retries=0, **kwargs)
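

# Illustrative usage of RemoteOpenAIServer (a sketch; the model name and serve
# args are placeholders): the context manager starts `vllm serve`, waits for
# the /health endpoint, and exposes OpenAI-compatible clients.
#
#   with RemoteOpenAIServer("Qwen/Qwen3-0.6B", ["--max-model-len", "1024"]) as server:
#       client = server.get_client()
#       ...  # issue OpenAI-style requests against server.url_for("v1")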


class RemoteEPDServer(RemoteOpenAIServer):
    def _start_server(self, model: str, server_cmd: list[str], env_dict: dict[str, str] | None) -> None:
        """Subclasses override this method to customize server process launch"""
        raise NotImplementedError("RemoteEPDServer should use _start_server_with_prefix instead")

    def __init__(
        self,
        vllm_serve_args: list[str] | list[list[str]],
        server_host: str = "0.0.0.0",
        env_dict: dict[str, str] | None = None,
        max_wait_seconds: float | None = 2800,
    ) -> None:
        self._proc_list = []

        self.env_dict: dict[str, str] = {}
        if env_dict is not None:
            self.env_dict.update(env_dict)

        self.env_dict["VLLM_ALLOW_LONG_MAX_MODEL_LEN"] = "1"
        self.env_dict["PYTORCH_NPU_ALLOC_CONF"] = "expandable_segments:True"
        self.env_dict["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        self.vllm_serve_args_list = []
        self.health_url_list = []
        self.host = server_host

        if isinstance(vllm_serve_args, list):
            if not all(isinstance(item, list) for item in vllm_serve_args):
                args_copy = copy.deepcopy(vllm_serve_args)
                self.vllm_serve_args_list.append([str(arg) for arg in args_copy])
            else:
                self.vllm_serve_args_list = [
                    [str(arg) for arg in sublist] for sublist in copy.deepcopy(vllm_serve_args)
                ]
        else:
            raise RuntimeError("vllm_serve_args must be a list")

        serve_arg_cmd = ["vllm", "serve"]

        for i, vllm_serve_arg in enumerate(self.vllm_serve_args_list):
            self.env_dict["ASCEND_RT_VISIBLE_DEVICES"] = str(i)
            if isinstance(vllm_serve_arg, list):
                if "--port" not in vllm_serve_arg:
                    raise ValueError("You must manually specify --port for each instance")
                else:
                    port_arg = "--port"
                    try:
                        index = vllm_serve_arg.index(port_arg)
                    except ValueError:
                        raise ValueError(f"--port not found in args: {vllm_serve_arg}")
                    port_str = vllm_serve_arg[index + 1]
                    self.port = int(port_str)
            else:
                vllm_serve_arg_str = str(vllm_serve_arg)
                if "--port" not in vllm_serve_arg_str:
                    raise ValueError("You must manually specify --port for each instance")
                else:
                    raise ValueError(f"Unexpected type for vllm_serve_arg: {type(vllm_serve_arg)}")

            self.health_url_list.append(super().url_for("health"))
            vllm_serve_arg = [*serve_arg_cmd, *vllm_serve_arg]
            proc = self._start_server_with_prefix(vllm_serve_arg, self.env_dict, f"[VLLM_{i}] ")
            self._proc_list.append(proc)

        timeout_value = float(max_wait_seconds) if max_wait_seconds is not None else 2800.0
        super()._wait_for_multiple_servers(
            [(self.host, url) for url in self.health_url_list], timeout=timeout_value, always_check_nodes=True
        )

    def _poll(self) -> int | None:
        return None

    def _delete_shm(self) -> None:
        for i, arg in enumerate(self.vllm_serve_args_list):
            if "--ec-transfer-config" in arg:
                index = arg.index("--ec-transfer-config")
                config_str = arg[index + 1]
                config_dict = json.loads(config_str)
                ec_connector_extra_config = config_dict.get("ec_connector_extra_config", {})
                shm_path = ec_connector_extra_config.get("shared_storage_path")
                if shm_path:
                    args = ["rm", "-r", "-f", str(shm_path)]
                    print(f"deleting shm_path: {shm_path}")
                    self._start_server_with_prefix(args, None, "[DELETE] ")

    def _read_output(self, pipe, prefix):
        try:
            with pipe:
                for line in iter(pipe.readline, ""):
                    if line:
                        print(f"{prefix}: {line}", end="")
        except Exception as e:
            print(f"error: {e}")
            traceback.print_exc()

    def _start_server_with_prefix(self, server_cmd: list[str], env_dict: dict[str, str] | None, log_prefix: str):
        env = os.environ.copy()
        if env_dict is not None:
            env.update(env_dict)
        proc = subprocess.Popen(
            server_cmd, env=env, stdout=subprocess.PIPE, stderr=subprocess.PIPE, universal_newlines=True, bufsize=1
        )
        stdout_thread = threading.Thread(target=self._read_output, args=(proc.stdout, log_prefix), daemon=True)
        stderr_thread = threading.Thread(target=self._read_output, args=(proc.stderr, log_prefix), daemon=True)

        stdout_thread.start()
        stderr_thread.start()
        return proc

    def _terminate_server(self) -> None:
        """kill process and its children"""
        print("vllm instance is stopping")
        for proc in self._proc_list:
            parent = psutil.Process(proc.pid)
            children = parent.children(recursive=True)
            for child in children:
                with contextlib.suppress(psutil.NoSuchProcess):
                    child.terminate()

            gone, still_alive = psutil.wait_procs(children, timeout=10)

            for child in still_alive:
                with contextlib.suppress(psutil.NoSuchProcess):
                    child.kill()

            try:
                parent.terminate()
                parent.wait(timeout=10)
            except (psutil.NoSuchProcess, psutil.TimeoutExpired):
                with contextlib.suppress(psutil.NoSuchProcess):
                    parent.kill()

    def __enter__(self):
        """Context manager entry point."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit point - clean up all processes."""
        self._terminate_server()


class DisaggEpdProxy(RemoteEPDServer):
    def __init__(
        self,
        proxy_args: list[str] | str | None = None,
        env_dict: dict[str, str] | None = None,
        server_host: str = "0.0.0.0",
        max_wait_seconds: float | None = 2800,
    ) -> None:
        if proxy_args is None:
            proxy_args_list: list[str] = []
        elif isinstance(proxy_args, str):
            proxy_args_list = shlex.split(proxy_args)
        else:
            proxy_args_list = proxy_args

        self.proxy_args = proxy_args_list
        self.env_dict: dict[str, str] = {}
        if env_dict is not None:
            self.env_dict.update(env_dict)
        self._proc_list = list()
        self.host = server_host

        print(f"proxy param is: {self.proxy_args}")
        proxy_cmd = ["python", str(DISAGG_EPD_PROXY_SCRIPT), *self.proxy_args]
        proc = self._start_server_with_prefix(proxy_cmd, self.env_dict, "[PROXY] ")
        self._proc_list.append(proc)

        if "--port" not in self.proxy_args:
            raise ValueError("You must manually specify --port for the proxy")
        else:
            try:
                index = self.proxy_args.index("--port")
            except ValueError:
                raise ValueError("--port not found in proxy args")
            port_str = self.proxy_args[index + 1]
            self.port = int(port_str)

        timeout_value = float(max_wait_seconds) if max_wait_seconds is not None else 2800.0
        super()._wait_for_multiple_servers([(self.host, super().url_for("health"))], timeout=timeout_value)

    def __enter__(self):
        """Context manager entry point."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit point - clean up all processes."""
        super()._terminate_server()
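

# Illustrative usage of the EPD helpers above (a sketch; the model name, ports,
# and argument lists are placeholders): every serve-args list must carry its own
# --port, and the proxy args must carry the proxy --port.
#
#   with RemoteEPDServer([["Qwen/Qwen3-0.6B", "--port", "8001"],
#                         ["Qwen/Qwen3-0.6B", "--port", "8002"]]) as servers, \
#           DisaggEpdProxy(["--port", "10001"]) as proxy:
#       ...  # send requests to the proxy endpoint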


_DP_RUNNER_START_TIMEOUT_SECONDS = 900.0
_DP_RUNNER_REQUEST_TIMEOUT_SECONDS = 900.0
_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS = 30.0


def _split_data_parallel_indices(num_items: int, dp_size: int) -> list[list[int]]:
    if num_items < 0:
        raise ValueError("num_items must be non-negative")
    if dp_size <= 0:
        raise ValueError("dp_size must be positive")

    floor = num_items // dp_size
    remainder = num_items % dp_size

    def start(rank: int) -> int:
        return rank * floor + min(rank, remainder)

    return [list(range(start(rank), start(rank + 1))) for rank in range(dp_size)]
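

# Worked example for the helper above: splitting 10 items across dp_size=3
# gives floor=3 and remainder=1, so rank 0 receives one extra item:
#   _split_data_parallel_indices(10, 3) == [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]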


def _slice_optional_inputs(inputs: PromptImageInput | PromptAudioInput | PromptVideoInput | None, indices: list[int]):
    if inputs is None:
        return None
    return [inputs[index] for index in indices]


def _slice_list_inputs(items: list[Any], indices: list[int]) -> list[Any]:
    return [items[index] for index in indices]


def _merge_data_parallel_results(total_items: int, shard_results: list[tuple[list[int], list[Any]]]) -> list[Any]:
    merged: list[Any] = [None] * total_items
    for indices, results in shard_results:
        if not indices:
            continue
        if len(indices) != len(results):
            raise RuntimeError("Mismatched result count returned by data parallel worker")
        for index, result in zip(indices, results):
            merged[index] = result

    if any(result is None for result in merged):
        raise RuntimeError("Some data parallel results were not returned")

    return merged


def _normalize_score_inputs(text_1: str | list[str], text_2: str | list[str]) -> tuple[list[str], list[str]]:
    if isinstance(text_1, str) and isinstance(text_2, str):
        return [text_1], [text_2]
    if isinstance(text_1, str):
        return [text_1] * len(text_2), list(text_2)
    if isinstance(text_2, str):
        return list(text_1), [text_2] * len(text_1)
    if len(text_1) != len(text_2):
        raise ValueError("`text_1` and `text_2` must have the same length")
    return list(text_1), list(text_2)
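

# Example of the normalization above: a single string is broadcast against a
# list, so _normalize_score_inputs("q", ["d1", "d2"]) returns
# (["q", "q"], ["d1", "d2"]).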


def _run_vllm_runner_dp_worker(conn, llm_kwargs: dict[str, Any], dp_rank: int, dp_size: int, master_port: int) -> None:
    llm = None
    try:
        os.environ["VLLM_DP_RANK"] = str(dp_rank)
        os.environ["VLLM_DP_RANK_LOCAL"] = str(dp_rank)
        os.environ["VLLM_DP_SIZE"] = str(dp_size)
        os.environ["VLLM_DP_MASTER_IP"] = "127.0.0.1"
        os.environ["VLLM_DP_MASTER_PORT"] = str(master_port)
        os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"

        llm = LLM(**llm_kwargs)
        conn.send({"status": "ready", "rank": dp_rank})

        while True:
            request = conn.recv()
            command = request["command"]
            if command == "shutdown":
                break

            result: Any
            if command == "generate":
                req_outputs = llm.generate(
                    request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"]
                )
                result = VllmRunner._finalize_generate_outputs(req_outputs)
            elif command == "generate_w_logprobs":
                req_outputs = llm.generate(
                    request["inputs"], sampling_params=request["sampling_params"], **request["kwargs"]
                )
                result = VllmRunner._final_steps_generate_w_logprobs(req_outputs)
            elif command == "classify":
                req_outputs = llm.classify(request["prompts"])
                result = [req_output.outputs.probs for req_output in req_outputs]
            elif command == "embed":
                req_outputs = llm.embed(request["inputs"], *request["args"], **request["kwargs"])
                result = [req_output.outputs.embedding for req_output in req_outputs]
            elif command == "encode":
                req_outputs = llm.encode(request["prompts"])
                result = [req_output.outputs.data for req_output in req_outputs]
            elif command == "reward":
                req_outputs = llm.reward(request["prompts"])
                result = [req_output.outputs.data for req_output in req_outputs]
            elif command == "score":
                req_outputs = llm.score(request["text_1"], request["text_2"], *request["args"], **request["kwargs"])
                result = [req_output.outputs.score for req_output in req_outputs]
            else:
                raise ValueError(f"Unsupported data parallel command: {command}")

            conn.send({"status": "ok", "rank": dp_rank, "indices": request["indices"], "result": result})
    except Exception:
        with contextlib.suppress(Exception):
            conn.send({"status": "error", "rank": dp_rank, "traceback": traceback.format_exc()})
        raise
    finally:
        if llm is not None:
            del llm
        clear_ascend_config()
        cleanup_dist_env_and_memory()
        with contextlib.suppress(Exception):
            conn.close()
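

# The worker above talks to the parent process over a multiprocessing Pipe with
# plain dict messages; a sketch of the protocol implemented above:
#   worker -> parent: {"status": "ready", "rank": 0}
#   parent -> worker: {"command": "generate", "inputs": [...], "sampling_params": ..., "kwargs": {...}, "indices": [...]}
#   worker -> parent: {"status": "ok", "rank": 0, "indices": [...], "result": [...]}
#   parent -> worker: {"command": "shutdown"}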


class VllmRunner:
    """Thin wrapper around `vllm.LLM` used by the e2e tests."""

    def __init__(
        self,
        model_name: str,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer_name: str | None = None,
        tokenizer_mode: str = "auto",
        max_model_len: int | None = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = True,
        swap_space: int = 4,
        enforce_eager: bool | None = False,
        quantization: str | None = None,
        **kwargs,
    ) -> None:
        data_parallel_size = int(kwargs.get("data_parallel_size", 1))
        if data_parallel_size > 1:
            raise ValueError("VllmRunner does not support `data_parallel_size > 1`; use `DPVllmRunner` instead.")

        self.model = LLM(
            model=model_name,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            quantization=quantization,
            **kwargs,
        )

    @staticmethod
    def _finalize_generate_outputs(req_outputs: list[RequestOutput]) -> list[tuple[list[list[int]], list[str]]]:
        outputs: list[tuple[list[list[int]], list[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: list[list[int]] = []
            req_sample_output_strs: list[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append((prompt_str or "") + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    def get_inputs(
        self,
        prompts: list[str] | list[torch.Tensor] | list[int],
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
    ) -> list[TextPrompt]:
        if any(x is not None and len(x) != len(prompts) for x in [images, videos, audios]):
            raise ValueError("All non-None multimodal inputs must have the same length as prompts")

        inputs = []
        for i, prompt in enumerate(prompts):
            multi_modal_data = {}
            if images is not None and (image := images[i]) is not None:
                multi_modal_data["image"] = image
            if videos is not None and (video := videos[i]) is not None:
                multi_modal_data["video"] = video  # type: ignore
            if audios is not None and (audio := audios[i]) is not None:
                multi_modal_data["audio"] = audio  # type: ignore

            text_prompt_kwargs: dict[str, Any] = {"multi_modal_data": multi_modal_data or None}
            if isinstance(prompt, str):
                text_prompt_kwargs["prompt"] = prompt
            elif isinstance(prompt, list):
                text_prompt_kwargs["prompt_token_ids"] = prompt
            else:
                text_prompt_kwargs["prompt_embeds"] = prompt

            inputs.append(TextPrompt(**text_prompt_kwargs))

        return inputs

    def generate(
        self,
        prompts: list[str] | list[torch.Tensor] | list[list[int]],
        sampling_params: SamplingParams,
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)
        req_outputs = self.model.generate(inputs, sampling_params=sampling_params, **kwargs)
        return self._finalize_generate_outputs(req_outputs)

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: list[RequestOutput],
    ) -> list[TokensTextLogprobsPromptLogprobs]:
        outputs: list[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                output_logprobs = sample.logprobs
                outputs.append((output_ids, output_str, output_logprobs, req_output.prompt_logprobs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: PromptImageInput | None = None,
        audios: PromptAudioInput | None = None,
        videos: PromptVideoInput | None = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)

        req_outputs = self.model.generate(inputs, sampling_params=sampling_params, **kwargs)

        toks_str_logsprobs_prompt_logprobs = self._final_steps_generate_w_logprobs(req_outputs)
        # Omit prompt logprobs if not required by sampling params
        return (
            [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
            if sampling_params.prompt_logprobs is None
            else toks_str_logsprobs_prompt_logprobs
        )
|
2025-02-05 10:53:12 +08:00
|
|
|
|
|
|
|
|
|
|
def generate_greedy(
|
|
|
|
|
|
self,
|
2026-03-23 15:44:21 +08:00
|
|
|
|
prompts: list[str] | list[torch.Tensor] | list[list[int]],
|
2025-02-05 10:53:12 +08:00
|
|
|
|
max_tokens: int,
|
2026-03-10 09:52:50 +08:00
|
|
|
|
images: PromptImageInput | None = None,
|
|
|
|
|
|
videos: PromptVideoInput | None = None,
|
|
|
|
|
|
audios: PromptAudioInput | None = None,
|
2025-12-10 11:37:57 +08:00
|
|
|
|
**kwargs: Any,
|
|
|
|
|
|
) -> list[tuple[list[int], str]]:
|
2025-02-05 10:53:12 +08:00
|
|
|
|
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
2026-03-10 09:52:50 +08:00
|
|
|
|
outputs = self.generate(prompts, greedy_params, images=images, videos=videos, audios=audios, **kwargs)
|
|
|
|
|
|
return [(output_ids[0], output_str[0]) for output_ids, output_str in outputs]
|
2025-02-05 10:53:12 +08:00
|
|
|
|
|
|
|
|
|
|
    def generate_greedy_logprobs(
        self,
        prompts: list[str],
        max_tokens: int,
        num_logprobs: int | None,
        num_prompt_logprobs: int | None = None,
        images: PromptImageInput | None = None,
        audios: PromptAudioInput | None = None,
        videos: PromptVideoInput | None = None,
        stop_token_ids: list[int] | None = None,
        stop: list[str] | None = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop,
        )

        return self.generate_w_logprobs(
            prompts, greedy_logprobs_params, images=images, audios=audios, videos=videos, **kwargs
        )

    def classify(self, prompts: list[str]) -> list[list[float]]:
        req_outputs = self.model.classify(prompts)
        return [req_output.outputs.probs for req_output in req_outputs]

    def embed(
        self,
        prompts: list[str],
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        *args,
        **kwargs,
    ) -> list[list[float]]:
        inputs = self.get_inputs(prompts, images=images, videos=videos, audios=audios)

        req_outputs = self.model.embed(inputs, *args, **kwargs)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def encode(self, prompts: list[str]) -> list[list[float]]:
        req_outputs = self.model.encode(prompts)
        return [req_output.outputs.data for req_output in req_outputs]

    def reward(self, prompts: list[str]) -> list[list[float]]:
        req_outputs = self.model.reward(prompts)
        return [req_output.outputs.data for req_output in req_outputs]

    def score(
        self,
        text_1: str | list[str],
        text_2: str | list[str],
        *args,
        **kwargs,
    ) -> list[float]:
        req_outputs = self.model.score(text_1, text_2, *args, **kwargs)
        return [req_output.outputs.score for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        clear_ascend_config()
        cleanup_dist_env_and_memory()

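
# Illustrative usage sketch (not executed by the test suite): tests typically drive the
# runner as a context manager so that `__exit__` frees NPU memory between cases. The
# model name is a placeholder, and the keyword arguments shown simply mirror the
# `DPVllmRunner` signature below; they are assumptions, not requirements of this module.
#
#     with VllmRunner("Qwen/Qwen3-0.6B", max_model_len=1024, enforce_eager=True) as runner:
#         token_ids, text = runner.generate_greedy(["Hello, my name is"], max_tokens=16)[0]
#         logprob_outputs = runner.generate_greedy_logprobs(
#             ["Hello, my name is"], max_tokens=16, num_logprobs=5
#         )
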
class DPVllmRunner(VllmRunner):
    def __init__(
        self,
        model_name: str,
        runner: RunnerOption = "auto",
        convert: ConvertOption = "auto",
        tokenizer_name: str | None = None,
        tokenizer_mode: str = "auto",
        max_model_len: int | None = 1024,
        dtype: str = "auto",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = True,
        swap_space: int = 4,
        enforce_eager: bool | None = False,
        quantization: str | None = None,
        data_parallel_size: int = 2,
        **kwargs,
    ) -> None:
        if data_parallel_size < 2:
            raise ValueError("DPVllmRunner requires `data_parallel_size >= 2`")

        self._dp_size = data_parallel_size
        self._dp_parent_conns: list[Any] = []
        self._dp_processes: list[Any] = []
        self._dp_start_timeout = float(kwargs.pop("dp_start_timeout", _DP_RUNNER_START_TIMEOUT_SECONDS))
        self._dp_request_timeout = float(kwargs.pop("dp_request_timeout", _DP_RUNNER_REQUEST_TIMEOUT_SECONDS))

        llm_kwargs = dict(
            model=model_name,
            runner=runner,
            convert=convert,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            quantization=quantization,
            **kwargs,
        )

        cleanup_dist_env_and_memory()
        self._start_data_parallel_workers(llm_kwargs)

    @property
    def model(self) -> LLM:
        raise RuntimeError("Direct access to `runner.model` is not supported by `DPVllmRunner`.")

    def _start_data_parallel_workers(self, llm_kwargs: dict[str, Any]) -> None:
        ctx = multiprocessing.get_context("spawn")
        master_port = get_open_port()

        try:
            for dp_rank in range(self._dp_size):
                parent_conn, child_conn = ctx.Pipe()
                proc = ctx.Process(
                    target=_run_vllm_runner_dp_worker,
                    args=(child_conn, llm_kwargs, dp_rank, self._dp_size, master_port),
                )
                proc.start()
                child_conn.close()
                self._dp_parent_conns.append(parent_conn)
                self._dp_processes.append(proc)

            for rank, conn in enumerate(self._dp_parent_conns):
                if not conn.poll(self._dp_start_timeout):
                    raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to start")
                message = conn.recv()
                if message["status"] != "ready":
                    raise RuntimeError(
                        f"Failed to start data parallel worker {rank}:\n{message.get('traceback', 'unknown error')}"
                    )
        except Exception:
            self._stop_data_parallel_workers()
            raise

    def _stop_data_parallel_workers(self) -> None:
        for conn in self._dp_parent_conns:
            with contextlib.suppress(Exception):
                conn.send({"command": "shutdown"})

        for proc in self._dp_processes:
            proc.join(timeout=_DP_RUNNER_SHUTDOWN_TIMEOUT_SECONDS)
            if proc.is_alive():
                proc.kill()
                proc.join(timeout=5)

        for conn in self._dp_parent_conns:
            with contextlib.suppress(Exception):
                conn.close()

        self._dp_parent_conns.clear()
        self._dp_processes.clear()

    def _dispatch_prompt_command(
        self,
        command: str,
        prompts: list[str] | list[torch.Tensor] | list[list[int]],
        *,
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        **payload: Any,
    ) -> list[Any]:
        if not prompts:
            return []

        shard_results: list[tuple[list[int], list[Any]]] = []
        shard_indices = _split_data_parallel_indices(len(prompts), self._dp_size)

        for rank, conn in enumerate(self._dp_parent_conns):
            indices = shard_indices[rank]
            worker_indices = indices or [0]
            worker_prompts = _slice_list_inputs(prompts, worker_indices)
            conn.send(
                {
                    "command": command,
                    "indices": indices,
                    "inputs": self.get_inputs(
                        worker_prompts,
                        images=_slice_optional_inputs(images, worker_indices),
                        videos=_slice_optional_inputs(videos, worker_indices),
                        audios=_slice_optional_inputs(audios, worker_indices),
                    ),
                    "prompts": worker_prompts,
                    **payload,
                }
            )

        try:
            for rank, conn in enumerate(self._dp_parent_conns):
                if not conn.poll(self._dp_request_timeout):
                    raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `{command}`")
                message = conn.recv()
                if message["status"] != "ok":
                    raise RuntimeError(
                        f"Data parallel worker {rank} failed during `{command}`:\n"
                        f"{message.get('traceback', 'unknown error')}"
                    )
                shard_results.append((message["indices"], message["result"]))
        except Exception:
            self._stop_data_parallel_workers()
            raise

        return _merge_data_parallel_results(len(prompts), shard_results)

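    # The dispatch above is a scatter/gather round trip: `_split_data_parallel_indices`
    # assigns each rank a (possibly empty) list of prompt indices, every worker replies
    # with the results for its shard, and `_merge_data_parallel_results` reassembles
    # them in the original prompt order. A minimal sketch of that contract, using
    # hypothetical values rather than the real helpers:
    #
    #     shard_indices = [[0, 2], [1, 3]]           # prompts owned by rank 0 / rank 1
    #     shard_results = [([0, 2], ["a", "c"]),     # (indices, per-prompt results)
    #                      ([1, 3], ["b", "d"])]
    #     merged = [None] * 4
    #     for indices, results in shard_results:
    #         for idx, res in zip(indices, results):
    #             merged[idx] = res
    #     assert merged == ["a", "b", "c", "d"]
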
    def _dispatch_text_command(self, command: str, prompts: list[str]) -> list[Any]:
        if not prompts:
            return []

        shard_results: list[tuple[list[int], list[Any]]] = []
        shard_indices = _split_data_parallel_indices(len(prompts), self._dp_size)

        for rank, conn in enumerate(self._dp_parent_conns):
            indices = shard_indices[rank]
            worker_indices = indices or [0]
            conn.send(
                {
                    "command": command,
                    "indices": indices,
                    "prompts": _slice_list_inputs(prompts, worker_indices),
                }
            )

        try:
            for rank, conn in enumerate(self._dp_parent_conns):
                if not conn.poll(self._dp_request_timeout):
                    raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `{command}`")
                message = conn.recv()
                if message["status"] != "ok":
                    raise RuntimeError(
                        f"Data parallel worker {rank} failed during `{command}`:\n"
                        f"{message.get('traceback', 'unknown error')}"
                    )
                shard_results.append((message["indices"], message["result"]))
        except Exception:
            self._stop_data_parallel_workers()
            raise

        return _merge_data_parallel_results(len(prompts), shard_results)

    def generate(
        self,
        prompts: list[str] | list[torch.Tensor] | list[list[int]],
        sampling_params: SamplingParams,
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        **kwargs: Any,
    ) -> list[tuple[list[list[int]], list[str]]]:
        return self._dispatch_prompt_command(
            "generate",
            prompts,
            images=images,
            videos=videos,
            audios=audios,
            sampling_params=sampling_params,
            kwargs=kwargs,
        )

    def generate_w_logprobs(
        self,
        prompts: list[str],
        sampling_params: SamplingParams,
        images: PromptImageInput | None = None,
        audios: PromptAudioInput | None = None,
        videos: PromptVideoInput | None = None,
        **kwargs: Any,
    ) -> list[TokensTextLogprobs] | list[TokensTextLogprobsPromptLogprobs]:
        toks_str_logsprobs_prompt_logprobs = self._dispatch_prompt_command(
            "generate_w_logprobs",
            prompts,
            images=images,
            videos=videos,
            audios=audios,
            sampling_params=sampling_params,
            kwargs=kwargs,
        )
        return (
            [x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
            if sampling_params.prompt_logprobs is None
            else toks_str_logsprobs_prompt_logprobs
        )

    def classify(self, prompts: list[str]) -> list[list[float]]:
        return self._dispatch_text_command("classify", prompts)

    def embed(
        self,
        prompts: list[str],
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
        *args,
        **kwargs,
    ) -> list[list[float]]:
        return self._dispatch_prompt_command(
            "embed",
            prompts,
            images=images,
            videos=videos,
            audios=audios,
            args=args,
            kwargs=kwargs,
        )

    def encode(self, prompts: list[str]) -> list[list[float]]:
        return self._dispatch_text_command("encode", prompts)

    def reward(self, prompts: list[str]) -> list[list[float]]:
        return self._dispatch_text_command("reward", prompts)

    def score(
        self,
        text_1: str | list[str],
        text_2: str | list[str],
        *args,
        **kwargs,
    ) -> list[float]:
        normalized_text_1, normalized_text_2 = _normalize_score_inputs(text_1, text_2)
        if not normalized_text_1:
            return []

        shard_results: list[tuple[list[int], list[Any]]] = []
        shard_indices = _split_data_parallel_indices(len(normalized_text_1), self._dp_size)

        for rank, conn in enumerate(self._dp_parent_conns):
            indices = shard_indices[rank]
            worker_indices = indices or [0]
            conn.send(
                {
                    "command": "score",
                    "indices": indices,
                    "text_1": _slice_list_inputs(normalized_text_1, worker_indices),
                    "text_2": _slice_list_inputs(normalized_text_2, worker_indices),
                    "args": args,
                    "kwargs": kwargs,
                }
            )

        try:
            for rank, conn in enumerate(self._dp_parent_conns):
                if not conn.poll(self._dp_request_timeout):
                    raise TimeoutError(f"Timed out waiting for data parallel worker {rank} to finish `score`")
                message = conn.recv()
                if message["status"] != "ok":
                    raise RuntimeError(
                        f"Data parallel worker {rank} failed during `score`:\n"
                        f"{message.get('traceback', 'unknown error')}"
                    )
                shard_results.append((message["indices"], message["result"]))
        except Exception:
            self._stop_data_parallel_workers()
            raise

        return _merge_data_parallel_results(len(normalized_text_1), shard_results)

    def __exit__(self, exc_type, exc_value, traceback):
        self._stop_data_parallel_workers()
        clear_ascend_config()
        cleanup_dist_env_and_memory()


DataParallelVllmRunner = DPVllmRunner

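
# Illustrative usage sketch (not executed by the test suite): `DPVllmRunner` spawns
# `data_parallel_size` spawn-context workers, each holding its own `LLM`, and shards the
# prompt list across them, so tests use it exactly like `VllmRunner` but must not touch
# `.model` directly. The model name below is a placeholder.
#
#     with DPVllmRunner("Qwen/Qwen3-0.6B", data_parallel_size=2, enforce_eager=True) as runner:
#         outputs = runner.generate_greedy(["Hello", "World"], max_tokens=8)
#         # `outputs` preserves the original prompt order even though the two prompts
#         # were served by different workers.
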
class HfRunner:
    def get_default_device(self):
        return "cpu" if current_platform.is_cpu() else current_platform.device_type

    def wrap_device(self, x: _T, device: str | None = None) -> _T:
        if x is None or isinstance(x, (bool,)):
            return x

        if device is None:
            device = self.device

        if isinstance(x, dict):
            return {k: self.wrap_device(v, device) for k, v in x.items()}

        if hasattr(x, "device") and x.device.type == device:
            return x

        return x.to(device)

    def __init__(
        self,
        model_name: str,
        dtype: str = "auto",
        *,
        model_kwargs: dict[str, Any] | None = None,
        trust_remote_code: bool = True,
        is_sentence_transformer: bool = False,
        is_cross_encoder: bool = False,
        skip_tokenizer_init: bool = False,
        auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
    ) -> None:
        model_name = maybe_model_redirect(model_name)
        self.model_name = model_name

        self.config = AutoConfig.from_pretrained(
            model_name,
            trust_remote_code=trust_remote_code,
        )
        self.device = self.get_default_device()
        self.dtype = torch_dtype = _get_and_verify_dtype(
            self.model_name,
            self.config,
            dtype=dtype,
            is_pooling_model=is_sentence_transformer or is_cross_encoder,
        )

        model_kwargs = model_kwargs if model_kwargs is not None else {}
        model_kwargs.setdefault("torch_dtype", torch_dtype)

        if is_sentence_transformer:
            # Lazy init required for AMD CI
            from sentence_transformers import SentenceTransformer

            self.model = SentenceTransformer(
                model_name,
                device=self.device,
                model_kwargs=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        elif is_cross_encoder:
            # Lazy init required for AMD CI
            from sentence_transformers import CrossEncoder

            self.model = CrossEncoder(
                model_name,
                device=self.device,
                automodel_args=model_kwargs,
                trust_remote_code=trust_remote_code,
            )
        else:
            model = auto_cls.from_pretrained(
                model_name,
                trust_remote_code=trust_remote_code,
                **model_kwargs,
            )

            # in case some unquantized custom models are not in same dtype
            if getattr(model, "quantization_method", None) is None and any(
                p.dtype != self.dtype for p in model.parameters()
            ):
                model = model.to(dtype=self.dtype)

            if (
                getattr(model, "quantization_method", None) != "bitsandbytes"
                and len({p.device for p in model.parameters()}) < 2
            ):
                model = model.to(device=self.device)

            self.model = model

        if not skip_tokenizer_init:
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                trust_remote_code=trust_remote_code,
            )

        # don't put this import at the top level
        # it will call torch.cuda.device_count()
        from transformers import AutoProcessor  # noqa: F401

        self.processor = AutoProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            trust_remote_code=trust_remote_code,
        )
        if skip_tokenizer_init:
            self.tokenizer = self.processor.tokenizer

    def get_inputs(
        self,
        prompts: list[str],
        images: PromptImageInput | None = None,
        videos: PromptVideoInput | None = None,
        audios: PromptAudioInput | None = None,
    ) -> list[BatchFeature | BatchEncoding]:
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        all_inputs: list[BatchFeature | BatchEncoding] = []
        for i, prompt in enumerate(prompts):
            processor_kwargs: dict[str, Any] = {
                "text": prompt,
                "return_tensors": "pt",
            }
            if images is not None and (image := images[i]) is not None:
                processor_kwargs["images"] = image
            if videos is not None and (video := videos[i]) is not None:
                processor_kwargs["videos"] = video
            if audios is not None and (audio_inputs := audios[i]) is not None:
                # HACK - not all processors take sampling_rate; we should
                # clean this up in the future.
                if len(audio_inputs) == 2:
                    audio, sr = audio_inputs
                    processor_kwargs["audio"] = audio
                    processor_kwargs["sampling_rate"] = sr
                else:
                    processor_kwargs["audio"] = audio_inputs

            inputs = self.processor(**processor_kwargs)
            if isinstance(inputs, BatchFeature):
                inputs = inputs.to(dtype=self.dtype)

            all_inputs.append(inputs)

        return all_inputs

    def classify(self, prompts: list[str]) -> list[list[float]]:
        # output is final logits
        all_inputs = self.get_inputs(prompts)
        outputs = []
        problem_type = getattr(self.config, "problem_type", "")

        for inputs in all_inputs:
            output = self.model(**self.wrap_device(inputs))
            if problem_type == "regression":
                logits = output.logits[0].tolist()
            elif problem_type == "multi_label_classification":
                logits = output.logits.sigmoid()[0].tolist()
            else:
                logits = output.logits.softmax(dim=-1)[0].tolist()
            outputs.append(logits)

        return outputs

    def encode(self, prompts: list[str], *args, **kwargs) -> list[list[torch.Tensor]]:
        return self.model.encode(prompts, *args, **kwargs)

    def predict(self, prompts: list[list[str]], *args, **kwargs) -> torch.Tensor:
        return self.model.predict(prompts, *args, convert_to_tensor=True, **kwargs)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup_dist_env_and_memory()

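
# Illustrative usage sketch (not executed by the test suite): `HfRunner` is the
# Transformers-side reference implementation that tests compare against. The model id
# below is a placeholder; `auto_cls` is only needed when the default causal-LM class is
# not the right architecture, as for the hypothetical classifier shown here.
#
#     from transformers import AutoModelForSequenceClassification
#
#     with HfRunner(
#         "some-org/some-classifier",  # placeholder model id
#         dtype="float32",
#         auto_cls=AutoModelForSequenceClassification,
#     ) as hf_runner:
#         hf_probs = hf_runner.classify(["This movie is great!"])
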
@pytest.fixture(scope="session")
def ilama_lora_files():
    return snapshot_download(
        repo_id="vllm-ascend/ilama-text2sql-spider",
        local_files_only=huggingface_hub.constants.HF_HUB_OFFLINE,
    )

@pytest.fixture(scope="session")
def llama32_lora_files():
    from huggingface_hub import snapshot_download as hf_snapshot_download

    return hf_snapshot_download(repo_id="jeeejeee/llama32-3b-text2sql-spider", local_files_only=True)

def qwen_prompt(questions: list[str]) -> list[str]:
    placeholder = "<|image_pad|>"
    return [
        (
            "<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
            f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
            f"{q}<|im_end|>\n<|im_start|>assistant\n"
        )
        for q in questions
    ]

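
# Worked example: for the question "What is in the image?", `qwen_prompt` renders the
# Qwen chat template around the image placeholder, producing roughly (newlines shown
# literally):
#
#     <|im_start|>system
#     You are a helpful assistant.<|im_end|>
#     <|im_start|>user
#     <|vision_start|><|image_pad|><|vision_end|>What is in the image?<|im_end|>
#     <|im_start|>assistant
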
def hunyuan_prompt(questions: list[str]) -> list[str]:
    placeholder = "<|hy_place▁holder▁no▁100|><|hy_place▁holder▁no▁102|><|hy_place▁holder▁no▁101|>"  # noqa: E501
    return [f"<|hy_begin▁of▁sentence|>{placeholder}{question}<|hy_User|>" for question in questions]

PROMPT_CONFIGS = {
    "qwen-vl": {
        "model": "Qwen/Qwen3-VL-8B-Instruct",
        "prompt_fn": qwen_prompt,
        "mm_processor_kwargs": {
            "min_pixels": 28 * 28,
            "max_pixels": 1280 * 28 * 28,
            "fps": 1,
        },
    },
    "hunyuan-vl": {
        "model": "Tencent-Hunyuan/HunyuanOCR",
        "prompt_fn": hunyuan_prompt,
        "mm_processor_kwargs": {},
    },
}

@pytest.fixture(params=PROMPT_CONFIGS.keys())
def vl_config(request):
    config = PROMPT_CONFIGS[request.param]
    if "skip" in config:
        pytest.skip(config["skip"])
    return config
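

# Illustrative usage sketch (not a test defined in this module): `vl_config`
# parametrizes a test over every entry in PROMPT_CONFIGS, so a multimodal test typically
# builds its prompts with the configured `prompt_fn` and forwards `mm_processor_kwargs`
# to the runner. The test name and `some_image` (any PIL image supplied by the test) are
# placeholders.
#
#     def test_vl_model_smoke(vl_config):
#         prompts = vl_config["prompt_fn"](["What is in the image?"])
#         with VllmRunner(
#             vl_config["model"],
#             mm_processor_kwargs=vl_config["mm_processor_kwargs"],
#         ) as runner:
#             runner.generate_greedy(prompts, max_tokens=8, images=[some_image])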