[CI]Fix broken CI (#1773)
This PR fixed the broken CI. It require
https://github.com/vllm-project/vllm/pull/20900 merged first.
- vLLM version: v0.9.2
- vLLM main:
e8cc53af5e
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
This commit is contained in:
@@ -47,6 +47,7 @@ from tests.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs,
|
|||||||
from vllm_ascend.utils import adapt_patch # noqa E402
|
from vllm_ascend.utils import adapt_patch # noqa E402
|
||||||
|
|
||||||
adapt_patch(True)
|
adapt_patch(True)
|
||||||
|
adapt_patch(False)
|
||||||
|
|
||||||
from vllm.distributed.parallel_state import ( # noqa E402
|
from vllm.distributed.parallel_state import ( # noqa E402
|
||||||
destroy_distributed_environment, destroy_model_parallel)
|
destroy_distributed_environment, destroy_model_parallel)
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ from vllm.v1.request import Request, RequestStatus
|
|||||||
from vllm.v1.structured_output import StructuredOutputManager
|
from vllm.v1.structured_output import StructuredOutputManager
|
||||||
|
|
||||||
from vllm_ascend.core.scheduler import AscendScheduler
|
from vllm_ascend.core.scheduler import AscendScheduler
|
||||||
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
EOS_TOKEN_ID = 50256
|
EOS_TOKEN_ID = 50256
|
||||||
|
|
||||||
@@ -303,6 +304,8 @@ def test_stop_via_update_from_output():
|
|||||||
req.num_computed_tokens = req.num_tokens
|
req.num_computed_tokens = req.num_tokens
|
||||||
scheduler.requests[req.request_id] = req
|
scheduler.requests[req.request_id] = req
|
||||||
scheduler.running.append(req)
|
scheduler.running.append(req)
|
||||||
|
if not vllm_version_is("0.9.2"):
|
||||||
|
req.status = RequestStatus.RUNNING
|
||||||
|
|
||||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||||
scheduled_cached_reqs=[],
|
scheduled_cached_reqs=[],
|
||||||
@@ -355,6 +358,8 @@ def test_stop_via_update_from_output():
|
|||||||
req.num_computed_tokens = req.num_tokens
|
req.num_computed_tokens = req.num_tokens
|
||||||
scheduler.requests[req.request_id] = req
|
scheduler.requests[req.request_id] = req
|
||||||
scheduler.running.append(req)
|
scheduler.running.append(req)
|
||||||
|
if not vllm_version_is("0.9.2"):
|
||||||
|
req.status = RequestStatus.RUNNING
|
||||||
|
|
||||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||||
scheduled_cached_reqs=[],
|
scheduled_cached_reqs=[],
|
||||||
@@ -405,6 +410,8 @@ def test_stop_via_update_from_output():
|
|||||||
req.num_computed_tokens = req.num_tokens
|
req.num_computed_tokens = req.num_tokens
|
||||||
scheduler.requests[req.request_id] = req
|
scheduler.requests[req.request_id] = req
|
||||||
scheduler.running.append(req)
|
scheduler.running.append(req)
|
||||||
|
if not vllm_version_is("0.9.2"):
|
||||||
|
req.status = RequestStatus.RUNNING
|
||||||
|
|
||||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||||
scheduled_cached_reqs=[],
|
scheduled_cached_reqs=[],
|
||||||
|
|||||||
@@ -1,259 +0,0 @@
|
|||||||
#
|
|
||||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
|
||||||
# Copyright 2023 The vLLM team.
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# This file is a part of the vllm-ascend project.
|
|
||||||
# Adapted from vllm/tests/entrypoints/openai/test_completion_with_prompt_embeds.py
|
|
||||||
#
|
|
||||||
import base64
|
|
||||||
import io
|
|
||||||
import os
|
|
||||||
|
|
||||||
import openai # use the official client for correctness check
|
|
||||||
import pytest
|
|
||||||
import pytest_asyncio
|
|
||||||
import torch
|
|
||||||
from modelscope import snapshot_download # type: ignore
|
|
||||||
from openai import BadRequestError
|
|
||||||
from transformers import AutoConfig
|
|
||||||
from vllm.engine.arg_utils import EngineArgs
|
|
||||||
|
|
||||||
from tests.utils import RemoteOpenAIServer
|
|
||||||
|
|
||||||
if not hasattr(EngineArgs, "enable_prompt_embeds"):
|
|
||||||
pytest.skip("Not supported vllm version", allow_module_level=True)
|
|
||||||
|
|
||||||
# any model with a chat template should work here
|
|
||||||
MODEL_NAME = snapshot_download("LLM-Research/Llama-3.2-1B-Instruct")
|
|
||||||
|
|
||||||
CONFIG = AutoConfig.from_pretrained(MODEL_NAME)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module")
|
|
||||||
def default_server_args() -> list[str]:
|
|
||||||
return [
|
|
||||||
# use half precision for speed and memory savings in CI environment
|
|
||||||
"--dtype",
|
|
||||||
"bfloat16",
|
|
||||||
"--max-model-len",
|
|
||||||
"8192",
|
|
||||||
"--max-num-seqs",
|
|
||||||
"128",
|
|
||||||
"--enforce-eager",
|
|
||||||
# Prompt Embeds server args
|
|
||||||
"--enable-prompt-embeds",
|
|
||||||
"--no-enable-chunked-prefill",
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(scope="module",
|
|
||||||
params=["", "--disable-frontend-multiprocessing"])
|
|
||||||
def server_with_prompt_embeds(default_server_args, request):
|
|
||||||
if request.param:
|
|
||||||
default_server_args.append(request.param)
|
|
||||||
|
|
||||||
with RemoteOpenAIServer(MODEL_NAME, default_server_args) as remote_server:
|
|
||||||
yield remote_server
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
|
||||||
async def client_with_prompt_embeds(server_with_prompt_embeds):
|
|
||||||
async with server_with_prompt_embeds.get_async_client() as async_client:
|
|
||||||
yield async_client
|
|
||||||
|
|
||||||
|
|
||||||
def create_dummy_embeds(num_tokens: int = 5) -> str:
|
|
||||||
"""Create dummy embeddings and return them as base64 encoded string."""
|
|
||||||
dummy_embeds = torch.randn(num_tokens, CONFIG.hidden_size)
|
|
||||||
buffer = io.BytesIO()
|
|
||||||
torch.save(dummy_embeds, buffer)
|
|
||||||
return base64.b64encode(buffer.getvalue()).decode('utf-8')
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
os.getenv("VLLM_USE_V1") == "1",
|
|
||||||
reason="Enable embedding input will fallback to v0, skip it")
|
|
||||||
async def test_completions_with_prompt_embeds(
|
|
||||||
client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
|
|
||||||
# Test case: Single prompt embeds input
|
|
||||||
encoded_embeds = create_dummy_embeds()
|
|
||||||
completion = await client_with_prompt_embeds.completions.create(
|
|
||||||
model=model_name,
|
|
||||||
prompt="", # Add empty prompt as required parameter
|
|
||||||
max_tokens=5,
|
|
||||||
temperature=0.0,
|
|
||||||
extra_body={"prompt_embeds": encoded_embeds})
|
|
||||||
assert len(completion.choices[0].text) >= 1
|
|
||||||
assert completion.choices[0].prompt_logprobs is None
|
|
||||||
|
|
||||||
# Test case: batch completion with prompt_embeds
|
|
||||||
encoded_embeds2 = create_dummy_embeds()
|
|
||||||
completion = await client_with_prompt_embeds.completions.create(
|
|
||||||
model=model_name,
|
|
||||||
prompt="", # Add empty prompt as required parameter
|
|
||||||
max_tokens=5,
|
|
||||||
temperature=0.0,
|
|
||||||
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
|
|
||||||
assert len(completion.choices) == 2
|
|
||||||
assert len(completion.choices[0].text) >= 1
|
|
||||||
assert len(completion.choices[1].text) >= 1
|
|
||||||
|
|
||||||
# Test case: streaming with prompt_embeds
|
|
||||||
encoded_embeds = create_dummy_embeds()
|
|
||||||
single_completion = await client_with_prompt_embeds.completions.create(
|
|
||||||
model=model_name,
|
|
||||||
prompt="", # Add empty prompt as required parameter
|
|
||||||
max_tokens=5,
|
|
||||||
temperature=0.0,
|
|
||||||
extra_body={"prompt_embeds": encoded_embeds})
|
|
||||||
single_output = single_completion.choices[0].text
|
|
||||||
|
|
||||||
stream = await client_with_prompt_embeds.completions.create(
|
|
||||||
model=model_name,
|
|
||||||
prompt="", # Add empty prompt as required parameter
|
|
||||||
max_tokens=5,
|
|
||||||
temperature=0.0,
|
|
||||||
stream=True,
|
|
||||||
extra_body={"prompt_embeds": encoded_embeds})
|
|
||||||
chunks = []
|
|
||||||
finish_reason_count = 0
|
|
||||||
async for chunk in stream:
|
|
||||||
chunks.append(chunk.choices[0].text)
|
|
||||||
if chunk.choices[0].finish_reason is not None:
|
|
||||||
finish_reason_count += 1
|
|
||||||
assert finish_reason_count == 1
|
|
||||||
assert chunk.choices[0].finish_reason == "length"
|
|
||||||
assert chunk.choices[0].text
|
|
||||||
assert "".join(chunks) == single_output
|
|
||||||
|
|
||||||
# Test case: batch streaming with prompt_embeds
|
|
||||||
encoded_embeds2 = create_dummy_embeds()
|
|
||||||
stream = await client_with_prompt_embeds.completions.create(
|
|
||||||
model=model_name,
|
|
||||||
prompt="", # Add empty prompt as required parameter
|
|
||||||
max_tokens=5,
|
|
||||||
temperature=0.0,
|
|
||||||
stream=True,
|
|
||||||
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
|
|
||||||
chunks_stream_embeds: list[list[str]] = [[], []]
|
|
||||||
finish_reason_count = 0
|
|
||||||
async for chunk in stream:
|
|
||||||
chunks_stream_embeds[chunk.choices[0].index].append(
|
|
||||||
chunk.choices[0].text)
|
|
||||||
if chunk.choices[0].finish_reason is not None:
|
|
||||||
finish_reason_count += 1
|
|
||||||
assert finish_reason_count == 2
|
|
||||||
assert chunk.choices[0].finish_reason == "length"
|
|
||||||
assert chunk.choices[0].text
|
|
||||||
assert len(chunks_stream_embeds[0]) > 0
|
|
||||||
assert len(chunks_stream_embeds[1]) > 0
|
|
||||||
|
|
||||||
# Test case: mixed text and prompt_embeds
|
|
||||||
encoded_embeds = create_dummy_embeds()
|
|
||||||
completion_mixed = await client_with_prompt_embeds.completions.create(
|
|
||||||
model=model_name,
|
|
||||||
prompt="This is a prompt",
|
|
||||||
max_tokens=5,
|
|
||||||
temperature=0.0,
|
|
||||||
extra_body={"prompt_embeds": encoded_embeds})
|
|
||||||
assert len(completion.choices) == 2
|
|
||||||
completion_text_only = await client_with_prompt_embeds.completions.create(
|
|
||||||
model=model_name,
|
|
||||||
prompt="This is a prompt",
|
|
||||||
max_tokens=5,
|
|
||||||
temperature=0.0,
|
|
||||||
)
|
|
||||||
completion_embeds_only = await client_with_prompt_embeds.completions.create(
|
|
||||||
model=model_name,
|
|
||||||
prompt="",
|
|
||||||
max_tokens=5,
|
|
||||||
temperature=0.0,
|
|
||||||
extra_body={"prompt_embeds": encoded_embeds})
|
|
||||||
# Embeddings responses should be handled first
|
|
||||||
assert completion_mixed.choices[0].text == completion_embeds_only.choices[
|
|
||||||
0].text
|
|
||||||
assert completion_mixed.choices[1].text == completion_text_only.choices[
|
|
||||||
0].text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
os.getenv("VLLM_USE_V1") == "1",
|
|
||||||
reason="Enable embedding input will fallback to v0, skip it")
|
|
||||||
async def test_completions_errors_with_prompt_embeds(
|
|
||||||
client_with_prompt_embeds: openai.AsyncOpenAI, model_name: str):
|
|
||||||
# Test error case: invalid prompt_embeds
|
|
||||||
with pytest.raises(BadRequestError):
|
|
||||||
await client_with_prompt_embeds.completions.create(
|
|
||||||
prompt="",
|
|
||||||
model=model_name,
|
|
||||||
max_tokens=5,
|
|
||||||
temperature=0.0,
|
|
||||||
extra_body={"prompt_embeds": "invalid_base64"})
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize("logprobs_arg", [1, 0])
|
|
||||||
@pytest.mark.parametrize("model_name", [MODEL_NAME])
|
|
||||||
@pytest.mark.skipif(
|
|
||||||
os.getenv("VLLM_USE_V1") == "1",
|
|
||||||
reason="Enable embedding input will fallback to v0, skip it")
|
|
||||||
async def test_completions_with_logprobs_and_prompt_embeds(
|
|
||||||
client_with_prompt_embeds: openai.AsyncOpenAI, logprobs_arg: int,
|
|
||||||
model_name: str):
|
|
||||||
# Test case: Logprobs using prompt_embeds
|
|
||||||
encoded_embeds = create_dummy_embeds()
|
|
||||||
completion = await client_with_prompt_embeds.completions.create(
|
|
||||||
model=model_name,
|
|
||||||
prompt="", # Add empty prompt as required parameter
|
|
||||||
max_tokens=5,
|
|
||||||
temperature=0.0,
|
|
||||||
echo=False,
|
|
||||||
logprobs=logprobs_arg,
|
|
||||||
extra_body={"prompt_embeds": encoded_embeds})
|
|
||||||
|
|
||||||
logprobs = completion.choices[0].logprobs
|
|
||||||
assert logprobs is not None
|
|
||||||
assert len(logprobs.text_offset) == 5
|
|
||||||
assert len(logprobs.token_logprobs) == 5
|
|
||||||
assert len(logprobs.top_logprobs) == 5
|
|
||||||
for top_logprobs in logprobs.top_logprobs[1:]:
|
|
||||||
assert max(logprobs_arg, 1) <= len(top_logprobs) <= logprobs_arg + 1
|
|
||||||
assert len(logprobs.tokens) == 5
|
|
||||||
|
|
||||||
# Test case: Log probs with batch completion and prompt_embeds
|
|
||||||
encoded_embeds2 = create_dummy_embeds()
|
|
||||||
completion = await client_with_prompt_embeds.completions.create(
|
|
||||||
model=model_name,
|
|
||||||
prompt="", # Add empty prompt as required parameter
|
|
||||||
max_tokens=5,
|
|
||||||
temperature=0.0,
|
|
||||||
echo=False,
|
|
||||||
logprobs=logprobs_arg,
|
|
||||||
extra_body={"prompt_embeds": [encoded_embeds, encoded_embeds2]})
|
|
||||||
|
|
||||||
assert len(completion.choices) == 2
|
|
||||||
for choice in completion.choices:
|
|
||||||
logprobs = choice.logprobs
|
|
||||||
assert logprobs is not None
|
|
||||||
assert len(logprobs.text_offset) == 5
|
|
||||||
assert len(logprobs.token_logprobs) == 5
|
|
||||||
assert len(logprobs.top_logprobs) == 5
|
|
||||||
for top_logprobs in logprobs.top_logprobs[1:]:
|
|
||||||
assert max(logprobs_arg,
|
|
||||||
1) <= len(top_logprobs) <= logprobs_arg + 1
|
|
||||||
assert len(logprobs.tokens) == 5
|
|
||||||
@@ -31,6 +31,7 @@ from vllm.v1.request import Request, RequestStatus
|
|||||||
from vllm.v1.structured_output import StructuredOutputManager
|
from vllm.v1.structured_output import StructuredOutputManager
|
||||||
|
|
||||||
from vllm_ascend.core.scheduler import AscendScheduler
|
from vllm_ascend.core.scheduler import AscendScheduler
|
||||||
|
from vllm_ascend.utils import vllm_version_is
|
||||||
|
|
||||||
EOS_TOKEN_ID = 50256
|
EOS_TOKEN_ID = 50256
|
||||||
|
|
||||||
@@ -213,6 +214,8 @@ def test_stop_via_update_from_output():
|
|||||||
scheduler.requests[req.request_id] = req
|
scheduler.requests[req.request_id] = req
|
||||||
scheduler.running.append(req)
|
scheduler.running.append(req)
|
||||||
scheduler.scheduled_req_ids.add(req.request_id)
|
scheduler.scheduled_req_ids.add(req.request_id)
|
||||||
|
if not vllm_version_is("0.9.2"):
|
||||||
|
req.status = RequestStatus.RUNNING
|
||||||
|
|
||||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||||
scheduled_cached_reqs=[],
|
scheduled_cached_reqs=[],
|
||||||
@@ -263,6 +266,8 @@ def test_stop_via_update_from_output():
|
|||||||
scheduler.requests[req.request_id] = req
|
scheduler.requests[req.request_id] = req
|
||||||
scheduler.running.append(req)
|
scheduler.running.append(req)
|
||||||
scheduler.scheduled_req_ids.add(req.request_id)
|
scheduler.scheduled_req_ids.add(req.request_id)
|
||||||
|
if not vllm_version_is("0.9.2"):
|
||||||
|
req.status = RequestStatus.RUNNING
|
||||||
|
|
||||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||||
scheduled_cached_reqs=[],
|
scheduled_cached_reqs=[],
|
||||||
@@ -311,6 +316,8 @@ def test_stop_via_update_from_output():
|
|||||||
scheduler.requests[req.request_id] = req
|
scheduler.requests[req.request_id] = req
|
||||||
scheduler.running.append(req)
|
scheduler.running.append(req)
|
||||||
scheduler.scheduled_req_ids.add(req.request_id)
|
scheduler.scheduled_req_ids.add(req.request_id)
|
||||||
|
if not vllm_version_is("0.9.2"):
|
||||||
|
req.status = RequestStatus.RUNNING
|
||||||
|
|
||||||
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
|
||||||
scheduled_cached_reqs=[],
|
scheduled_cached_reqs=[],
|
||||||
|
|||||||
132
tests/utils.py
132
tests/utils.py
@@ -20,146 +20,16 @@
|
|||||||
import functools
|
import functools
|
||||||
import os
|
import os
|
||||||
import signal
|
import signal
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
from collections.abc import Sequence
|
from collections.abc import Sequence
|
||||||
from typing import Callable, Optional
|
from typing import Callable
|
||||||
|
|
||||||
import openai
|
|
||||||
import requests
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
from typing_extensions import ParamSpec
|
from typing_extensions import ParamSpec
|
||||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
|
||||||
from vllm.entrypoints.openai.cli_args import make_arg_parser
|
|
||||||
from vllm.model_executor.model_loader import get_model_loader
|
|
||||||
from vllm.utils import FlexibleArgumentParser, get_open_port
|
|
||||||
|
|
||||||
_P = ParamSpec("_P")
|
_P = ParamSpec("_P")
|
||||||
|
|
||||||
|
|
||||||
class RemoteOpenAIServer:
|
|
||||||
DUMMY_API_KEY = "token-abc123" # vLLM's OpenAI server does not need API key
|
|
||||||
|
|
||||||
def __init__(self,
|
|
||||||
model: str,
|
|
||||||
vllm_serve_args: list[str],
|
|
||||||
*,
|
|
||||||
env_dict: Optional[dict[str, str]] = None,
|
|
||||||
seed: Optional[int] = 0,
|
|
||||||
auto_port: bool = True,
|
|
||||||
max_wait_seconds: Optional[float] = None) -> None:
|
|
||||||
if auto_port:
|
|
||||||
if "-p" in vllm_serve_args or "--port" in vllm_serve_args:
|
|
||||||
raise ValueError("You have manually specified the port "
|
|
||||||
"when `auto_port=True`.")
|
|
||||||
|
|
||||||
# Don't mutate the input args
|
|
||||||
vllm_serve_args = vllm_serve_args + [
|
|
||||||
"--port", str(get_open_port())
|
|
||||||
]
|
|
||||||
if seed is not None:
|
|
||||||
if "--seed" in vllm_serve_args:
|
|
||||||
raise ValueError("You have manually specified the seed "
|
|
||||||
f"when `seed={seed}`.")
|
|
||||||
|
|
||||||
vllm_serve_args = vllm_serve_args + ["--seed", str(seed)]
|
|
||||||
|
|
||||||
parser = FlexibleArgumentParser(
|
|
||||||
description="vLLM's remote OpenAI server.")
|
|
||||||
parser = make_arg_parser(parser)
|
|
||||||
args = parser.parse_args(["--model", model, *vllm_serve_args])
|
|
||||||
self.host = str(args.host or 'localhost')
|
|
||||||
self.port = int(args.port)
|
|
||||||
|
|
||||||
self.show_hidden_metrics = \
|
|
||||||
args.show_hidden_metrics_for_version is not None
|
|
||||||
|
|
||||||
# download the model before starting the server to avoid timeout
|
|
||||||
is_local = os.path.isdir(model)
|
|
||||||
if not is_local:
|
|
||||||
engine_args = AsyncEngineArgs.from_cli_args(args)
|
|
||||||
model_config = engine_args.create_model_config()
|
|
||||||
load_config = engine_args.create_load_config()
|
|
||||||
|
|
||||||
model_loader = get_model_loader(load_config)
|
|
||||||
model_loader.download_model(model_config)
|
|
||||||
|
|
||||||
env = os.environ.copy()
|
|
||||||
# the current process might initialize cuda,
|
|
||||||
# to be safe, we should use spawn method
|
|
||||||
env['VLLM_WORKER_MULTIPROC_METHOD'] = 'spawn'
|
|
||||||
if env_dict is not None:
|
|
||||||
env.update(env_dict)
|
|
||||||
self.proc = subprocess.Popen(
|
|
||||||
["vllm", "serve", model, *vllm_serve_args],
|
|
||||||
env=env,
|
|
||||||
stdout=sys.stdout,
|
|
||||||
stderr=sys.stderr,
|
|
||||||
)
|
|
||||||
max_wait_seconds = max_wait_seconds or 240
|
|
||||||
self._wait_for_server(url=self.url_for("health"),
|
|
||||||
timeout=max_wait_seconds)
|
|
||||||
|
|
||||||
def __enter__(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_value, traceback):
|
|
||||||
self.proc.terminate()
|
|
||||||
try:
|
|
||||||
self.proc.wait(8)
|
|
||||||
except subprocess.TimeoutExpired:
|
|
||||||
# force kill if needed
|
|
||||||
self.proc.kill()
|
|
||||||
|
|
||||||
def _wait_for_server(self, *, url: str, timeout: float):
|
|
||||||
# run health check
|
|
||||||
start = time.time()
|
|
||||||
while True:
|
|
||||||
try:
|
|
||||||
if requests.get(url).status_code == 200:
|
|
||||||
break
|
|
||||||
except Exception:
|
|
||||||
# this exception can only be raised by requests.get,
|
|
||||||
# which means the server is not ready yet.
|
|
||||||
# the stack trace is not useful, so we suppress it
|
|
||||||
# by using `raise from None`.
|
|
||||||
result = self.proc.poll()
|
|
||||||
if result is not None and result != 0:
|
|
||||||
raise RuntimeError("Server exited unexpectedly.") from None
|
|
||||||
|
|
||||||
time.sleep(0.5)
|
|
||||||
if time.time() - start > timeout:
|
|
||||||
raise RuntimeError(
|
|
||||||
"Server failed to start in time.") from None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def url_root(self) -> str:
|
|
||||||
return f"http://{self.host}:{self.port}"
|
|
||||||
|
|
||||||
def url_for(self, *parts: str) -> str:
|
|
||||||
return self.url_root + "/" + "/".join(parts)
|
|
||||||
|
|
||||||
def get_client(self, **kwargs):
|
|
||||||
if "timeout" not in kwargs:
|
|
||||||
kwargs["timeout"] = 600
|
|
||||||
return openai.OpenAI(
|
|
||||||
base_url=self.url_for("v1"),
|
|
||||||
api_key=self.DUMMY_API_KEY,
|
|
||||||
max_retries=0,
|
|
||||||
**kwargs,
|
|
||||||
)
|
|
||||||
|
|
||||||
def get_async_client(self, **kwargs):
|
|
||||||
if "timeout" not in kwargs:
|
|
||||||
kwargs["timeout"] = 600
|
|
||||||
return openai.AsyncOpenAI(base_url=self.url_for("v1"),
|
|
||||||
api_key=self.DUMMY_API_KEY,
|
|
||||||
max_retries=0,
|
|
||||||
**kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def fork_new_process_for_each_test(
|
def fork_new_process_for_each_test(
|
||||||
f: Callable[_P, None]) -> Callable[_P, None]:
|
f: Callable[_P, None]) -> Callable[_P, None]:
|
||||||
"""Decorator to fork a new process for each test function.
|
"""Decorator to fork a new process for each test function.
|
||||||
|
|||||||
@@ -63,7 +63,6 @@ from vllm.v1.sample.metadata import SamplingMetadata
|
|||||||
from vllm.v1.sample.sampler import Sampler
|
from vllm.v1.sample.sampler import Sampler
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||||
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
from vllm.v1.spec_decode.ngram_proposer import NgramProposer
|
||||||
from vllm.v1.utils import bind_kv_cache
|
|
||||||
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
from vllm.v1.worker.lora_model_runner_mixin import LoRAModelRunnerMixin
|
||||||
from vllm.v1.worker.utils import (gather_mm_placeholders,
|
from vllm.v1.worker.utils import (gather_mm_placeholders,
|
||||||
sanity_check_mm_encoder_outputs,
|
sanity_check_mm_encoder_outputs,
|
||||||
@@ -83,11 +82,16 @@ from vllm_ascend.utils import (ACL_FORMAT_FRACTAL_ND, ACL_FORMAT_FRACTAL_NZ,
|
|||||||
ProfileExecuteDuration,
|
ProfileExecuteDuration,
|
||||||
check_torchair_cache_exist, is_310p,
|
check_torchair_cache_exist, is_310p,
|
||||||
maybe_converting_weight_acl_format,
|
maybe_converting_weight_acl_format,
|
||||||
write_kv_cache_bytes_to_file)
|
vllm_version_is, write_kv_cache_bytes_to_file)
|
||||||
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
|
from vllm_ascend.worker.eagle_proposer_v1 import EagleProposer
|
||||||
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
|
from vllm_ascend.worker.mtp_proposer_v1 import MtpProposer
|
||||||
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
from vllm_ascend.worker.npu_input_batch import CachedRequestState, InputBatch
|
||||||
|
|
||||||
|
if vllm_version_is("0.9.2"):
|
||||||
|
from vllm.v1.utils import bind_kv_cache
|
||||||
|
else:
|
||||||
|
from vllm.v1.worker.utils import bind_kv_cache
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
import xgrammar as xgr # type: ignore[import-untyped]
|
import xgrammar as xgr # type: ignore[import-untyped]
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
|
|||||||
Reference in New Issue
Block a user