[Test] Remove VLLM_USE_V1 in example and tests (#1733)

V1 is enabled by default, so there is no need to set it by hand anymore. This PR
removes the now-useless setting from the examples and tests.
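Before this change, tests forced V1 through an environment patch; with V1 as the
default, the patch keeps only the remaining variables. A minimal sketch of the
updated pattern, mirroring the multicard test diff below:

@patch.dict(os.environ, {
    "VLLM_WORKER_MULTIPROC_METHOD": "spawn",
    "TASK_QUEUE_ENABLE": "1"
})
def test_generate_with_alltoall():
    ...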

- vLLM version: v0.9.2
- vLLM main:
9ad0a4588b

Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
wangxiyuan
2025-07-15 12:49:57 +08:00
committed by GitHub
parent eb921d2b6f
commit 787010a637
29 changed files with 186 additions and 291 deletions

tests/e2e/conftest.py (new file, 519 lines added)

@@ -0,0 +1,519 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import contextlib
import gc
import os
from typing import Any, List, Optional, Tuple, TypeVar, Union
import numpy as np
import pytest
import torch
from modelscope import snapshot_download # type: ignore[import-untyped]
from PIL import Image
from torch import nn
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
BatchEncoding, BatchFeature)
from transformers.models.auto.auto_factory import _BaseAutoModelClass
from vllm import LLM, SamplingParams
from vllm.config import TaskOption, _get_and_verify_dtype
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.transformers_utils.utils import maybe_model_redirect
from vllm.utils import is_list_of
from tests.e2e.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs)
# TODO: remove this part after the patch is merged into vllm; if we do not
# explicitly patch here, some of the patches might be ineffective in the
# pytest scenario
from vllm_ascend.utils import adapt_patch # noqa E402
adapt_patch(True)
adapt_patch(False)
from vllm.distributed.parallel_state import ( # noqa E402
destroy_distributed_environment, destroy_model_parallel)
_T = TypeVar("_T", nn.Module, torch.Tensor, BatchEncoding, BatchFeature, dict)
_M = TypeVar("_M")
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]
PromptImageInput = _PromptMultiModalInput[Image.Image]
PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
_TEST_DIR = os.path.dirname(__file__)
_TEST_PROMPTS = [os.path.join(_TEST_DIR, "prompts", "example.txt")]
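# Helper adapted from vLLM's tests/conftest.py: tear down model-parallel and
# distributed state, then release Ascend memory via torch.npu rather than CUDA.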
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
destroy_model_parallel()
destroy_distributed_environment()
with contextlib.suppress(AssertionError):
torch.distributed.destroy_process_group()
if shutdown_ray:
import ray # Lazy import Ray
ray.shutdown()
gc.collect()
torch.npu.empty_cache()
torch.npu.reset_peak_memory_stats()
class VllmRunner:
def __init__(
self,
model_name: str,
task: TaskOption = "auto",
tokenizer_name: Optional[str] = None,
tokenizer_mode: str = "auto",
# Use a smaller max model length; otherwise larger models cannot run due
# to the KV cache size limit.
max_model_len: int = 1024,
dtype: str = "half",
disable_log_stats: bool = True,
tensor_parallel_size: int = 1,
block_size: int = 16,
enable_chunked_prefill: bool = False,
swap_space: int = 4,
enforce_eager: Optional[bool] = True,
quantization: Optional[str] = None,
**kwargs,
) -> None:
self.model = LLM(
model=model_name,
task=task,
tokenizer=tokenizer_name,
tokenizer_mode=tokenizer_mode,
trust_remote_code=True,
dtype=dtype,
swap_space=swap_space,
enforce_eager=enforce_eager,
disable_log_stats=disable_log_stats,
tensor_parallel_size=tensor_parallel_size,
max_model_len=max_model_len,
block_size=block_size,
enable_chunked_prefill=enable_chunked_prefill,
quantization=quantization,
**kwargs,
)
def get_inputs(
self,
prompts: List[str],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[TextPrompt]:
if images is not None:
assert len(prompts) == len(images)
if videos is not None:
assert len(prompts) == len(videos)
if audios is not None:
assert len(prompts) == len(audios)
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
if images is not None:
for i, image in enumerate(images):
if image is not None:
inputs[i]["multi_modal_data"] = {"image": image}
if videos is not None:
for i, video in enumerate(videos):
if video is not None:
inputs[i]["multi_modal_data"] = {"video": video}
if audios is not None:
for i, audio in enumerate(audios):
if audio is not None:
inputs[i]["multi_modal_data"] = {"audio": audio}
return inputs
def generate(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[Tuple[List[List[int]], List[str]]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
outputs: List[Tuple[List[List[int]], List[str]]] = []
for req_output in req_outputs:
prompt_str = req_output.prompt
prompt_ids = req_output.prompt_token_ids
req_sample_output_ids: List[List[int]] = []
req_sample_output_strs: List[str] = []
for sample in req_output.outputs:
output_str = sample.text
output_ids = list(sample.token_ids)
req_sample_output_ids.append(prompt_ids + output_ids)
req_sample_output_strs.append(prompt_str + output_str)
outputs.append((req_sample_output_ids, req_sample_output_strs))
return outputs
@staticmethod
def _final_steps_generate_w_logprobs(
req_outputs: List[RequestOutput],
) -> List[TokensTextLogprobsPromptLogprobs]:
outputs: List[TokensTextLogprobsPromptLogprobs] = []
for req_output in req_outputs:
assert len(req_output.outputs) > 0
for sample in req_output.outputs:
output_str = sample.text
output_ids = list(sample.token_ids)
output_logprobs = sample.logprobs
outputs.append((output_ids, output_str, output_logprobs,
req_output.prompt_logprobs))
return outputs
def generate_w_logprobs(
self,
prompts: List[str],
sampling_params: SamplingParams,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.model.generate(inputs,
sampling_params=sampling_params)
toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs))
# Omit prompt logprobs if not required by sampling params
return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
if sampling_params.prompt_logprobs is None else
toks_str_logsprobs_prompt_logprobs)
def generate_encoder_decoder_w_logprobs(
self,
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
sampling_params: SamplingParams,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
'''
Logprobs generation for vLLM encoder/decoder models
'''
assert sampling_params.logprobs is not None
req_outputs = self.model.generate(encoder_decoder_prompts,
sampling_params=sampling_params)
toks_str_logsprobs_prompt_logprobs = (
self._final_steps_generate_w_logprobs(req_outputs))
# Omit prompt logprobs if not required by sampling params
return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
if sampling_params.prompt_logprobs is None else
toks_str_logsprobs_prompt_logprobs)
def generate_greedy(
self,
prompts: List[str],
max_tokens: int,
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[Tuple[List[int], str]]:
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
outputs = self.generate(prompts,
greedy_params,
images=images,
videos=videos,
audios=audios)
return [(output_ids[0], output_str[0])
for output_ids, output_str in outputs]
def generate_greedy_logprobs(
self,
prompts: List[str],
max_tokens: int,
num_logprobs: int,
num_prompt_logprobs: Optional[int] = None,
images: Optional[PromptImageInput] = None,
audios: Optional[PromptAudioInput] = None,
videos: Optional[PromptVideoInput] = None,
stop_token_ids: Optional[List[int]] = None,
stop: Optional[List[str]] = None,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
greedy_logprobs_params = SamplingParams(
temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs,
prompt_logprobs=num_prompt_logprobs,
stop_token_ids=stop_token_ids,
stop=stop)
return self.generate_w_logprobs(prompts,
greedy_logprobs_params,
images=images,
audios=audios,
videos=videos)
def generate_encoder_decoder_greedy_logprobs(
self,
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
max_tokens: int,
num_logprobs: int,
num_prompt_logprobs: Optional[int] = None,
) -> Union[List[TokensTextLogprobs],
List[TokensTextLogprobsPromptLogprobs]]:
'''
Greedy logprobs generation for vLLM encoder/decoder models
'''
greedy_logprobs_params = SamplingParams(
temperature=0.0,
max_tokens=max_tokens,
logprobs=num_logprobs,
prompt_logprobs=num_prompt_logprobs,
)
return self.generate_encoder_decoder_w_logprobs(
encoder_decoder_prompts, greedy_logprobs_params)
def generate_beam_search(
self,
prompts: Union[List[str], List[List[int]]],
beam_width: int,
max_tokens: int,
) -> List[Tuple[List[List[int]], List[str]]]:
if is_list_of(prompts, str, check="all"):
prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
else:
prompts = [
TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
]
outputs = self.model.beam_search(
prompts,
BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
returned_outputs = []
for output in outputs:
token_ids = [x.tokens for x in output.sequences]
texts = [x.text for x in output.sequences]
returned_outputs.append((token_ids, texts))
return returned_outputs
def classify(self, prompts: List[str]) -> List[List[float]]:
req_outputs = self.model.classify(prompts)
return [req_output.outputs.probs for req_output in req_outputs]
def encode(
self,
prompts: List[str],
images: Optional[PromptImageInput] = None,
videos: Optional[PromptVideoInput] = None,
audios: Optional[PromptAudioInput] = None,
) -> List[List[float]]:
inputs = self.get_inputs(prompts,
images=images,
videos=videos,
audios=audios)
req_outputs = self.model.embed(inputs)
return [req_output.outputs.embedding for req_output in req_outputs]
def score(
self,
text_1: Union[str, List[str]],
text_2: Union[str, List[str]],
) -> List[float]:
req_outputs = self.model.score(text_1, text_2)
return [req_output.outputs.score for req_output in req_outputs]
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
del self.model
cleanup_dist_env_and_memory()
@pytest.fixture(scope="session")
def vllm_runner():
return VllmRunner
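# Typical usage in a test (illustrative sketch; the model name is only an
# example taken from the tests in this commit):
#   with vllm_runner("Qwen/Qwen2.5-0.5B-Instruct",
#                    max_model_len=1024,
#                    enforce_eager=True) as runner:
#       outputs = runner.generate_greedy(["Hello, my name is"], max_tokens=32)
# Leaving the context deletes the LLM and runs cleanup_dist_env_and_memory().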
@pytest.fixture(params=list(PROMPT_TEMPLATES.keys()))
def prompt_template(request):
return PROMPT_TEMPLATES[request.param]
def _read_prompts(filename: str) -> list[str]:
with open(filename) as f:
prompts = f.readlines()
return prompts
@pytest.fixture
def example_prompts() -> list[str]:
prompts = []
for filename in _TEST_PROMPTS:
prompts += _read_prompts(filename)
return prompts
@pytest.fixture(scope="session")
def ilama_lora_files():
return snapshot_download(repo_id="vllm-ascend/ilama-text2sql-spider")
class HfRunner:
def get_default_device(self):
from vllm.platforms import current_platform
return ("cpu"
if current_platform.is_cpu() else current_platform.device_type)
def wrap_device(self, x: _T, device: Optional[str] = None) -> _T:
if x is None or isinstance(x, (bool, )):
return x
if device is None:
device = self.device
if isinstance(x, dict):
return {k: self.wrap_device(v, device) for k, v in x.items()}
if hasattr(x, "device") and x.device.type == device:
return x
return x.to(device)
def __init__(
self,
model_name: str,
dtype: str = "auto",
*,
model_kwargs: Optional[dict[str, Any]] = None,
trust_remote_code: bool = True,
is_sentence_transformer: bool = False,
is_cross_encoder: bool = False,
skip_tokenizer_init: bool = False,
auto_cls: type[_BaseAutoModelClass] = AutoModelForCausalLM,
) -> None:
model_name = maybe_model_redirect(model_name)
self.model_name = model_name
self.config = AutoConfig.from_pretrained(
model_name,
trust_remote_code=trust_remote_code,
)
self.device = self.get_default_device()
self.dtype = torch_dtype = _get_and_verify_dtype(
self.model_name,
self.config,
dtype=dtype,
is_pooling_model=is_sentence_transformer or is_cross_encoder,
)
model_kwargs = model_kwargs if model_kwargs is not None else {}
model_kwargs.setdefault("torch_dtype", torch_dtype)
if is_sentence_transformer:
# Lazy init required for AMD CI
from sentence_transformers import SentenceTransformer
self.model = SentenceTransformer(
model_name,
device=self.device,
model_kwargs=model_kwargs,
trust_remote_code=trust_remote_code,
)
elif is_cross_encoder:
# Lazy init required for AMD CI
from sentence_transformers import CrossEncoder
self.model = CrossEncoder(
model_name,
device=self.device,
automodel_args=model_kwargs,
trust_remote_code=trust_remote_code,
)
else:
model = auto_cls.from_pretrained(
model_name,
trust_remote_code=trust_remote_code,
**model_kwargs,
)
# in case some unquantized custom models are not in the same dtype
if (getattr(model, "quantization_method", None) is None
and any(p.dtype != self.dtype
for p in model.parameters())):
model = model.to(dtype=self.dtype)
if (getattr(model, "quantization_method", None) != "bitsandbytes"
and len({p.device
for p in model.parameters()}) < 2):
model = model.to(device=self.device)
self.model = model
if not skip_tokenizer_init:
self.tokenizer = AutoTokenizer.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
)
# don't put this import at the top level
# it will call torch.cuda.device_count()
from transformers import AutoProcessor # noqa: F401
self.processor = AutoProcessor.from_pretrained(
model_name,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
)
if skip_tokenizer_init:
self.tokenizer = self.processor.tokenizer
def encode(self, prompts: list[str], *args,
**kwargs) -> list[list[torch.Tensor]]:
return self.model.encode(prompts, *args, **kwargs)
def __enter__(self):
return self
def __exit__(self, exc_type, exc_value, traceback):
del self.model
cleanup_dist_env_and_memory()
@pytest.fixture(scope="session")
def hf_runner():
return HfRunner
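# Like vllm_runner, this fixture returns the class itself. Illustrative usage
# (the model name is a placeholder, not one used in this commit):
#   with hf_runner("some-embedding-model", is_sentence_transformer=True) as hf:
#       hf_embeddings = hf.encode(example_prompts)
# which provides HuggingFace reference outputs to compare against vLLM.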

tests/e2e/model_utils.py (new file, 274 lines added)

@@ -0,0 +1,274 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/models/utils.py
#
import warnings
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
TokensText = Tuple[List[int], str]
def check_outputs_equal(
*,
outputs_0_lst: Sequence[TokensText],
outputs_1_lst: Sequence[TokensText],
name_0: str,
name_1: str,
):
"""
Compare the two sequences generated by different models,
which should be equal.
"""
assert len(outputs_0_lst) == len(outputs_1_lst)
for prompt_idx, (outputs_0,
outputs_1) in enumerate(zip(outputs_0_lst,
outputs_1_lst)):
output_ids_0, output_str_0 = outputs_0
output_ids_1, output_str_1 = outputs_1
# The text and token outputs should exactly match
fail_msg = (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
assert output_str_0 == output_str_1, fail_msg
assert output_ids_0 == output_ids_1, fail_msg
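# Illustrative usage (the variable names are examples): compare greedy
# generations from two runners that should match token-for-token, e.g.
#   check_outputs_equal(outputs_0_lst=hf_outputs, outputs_1_lst=vllm_outputs,
#                       name_0="hf", name_1="vllm")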
# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * List of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
float]],
SampleLogprobs]]]
# Allow for tokens to be represented as str's rather than IDs;
# tuple of
# * Token string representations list
# * String
# * Optional list of top sample logprobs for each sampled token
#
# Assumes prompt logprobs were not requested.
TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]],
List[Dict[str,
Logprob]]]]]
# Representation of generated sequence as a tuple of
# * Token ID list
# * String
# * Optional list of top sample logprobs for each sampled token
# * Optional list of top prompt logprobs for each prompt token
#
# Allows prompt logprobs to be requested.
TokensTextLogprobsPromptLogprobs = Tuple[
List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]],
Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]]
def check_logprobs_close(
*,
outputs_0_lst: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
TextTextLogprobs]],
outputs_1_lst: Sequence[Union[TokensTextLogprobs,
TokensTextLogprobsPromptLogprobs,
TextTextLogprobs]],
name_0: str,
name_1: str,
num_outputs_0_skip_tokens: int = 0,
warn_on_mismatch: bool = True,
always_check_logprobs: bool = False,
) -> None:
"""Compare the logprobs of two sequences generated by different models,
which should be similar but not necessarily equal.
How sample logprobs are compared:
* `always_check_logprobs == True`: set of highest-logprob token ids
must match between seq0 and seq1 at all sampled token offsets
* `always_check_logprobs == False`: highest-logprob token ids are
only compared at sampled token offsets for which generated token
ids don't match
Prompt logprobs must be provided either for both input sequences, or
for neither. If prompt logprobs are provided, then highest-logprob
prompt token ids must match between seq0 and seq1 at all prompt token
offsets.
Args:
outputs_0_lst: First sequence to compare
outputs_1_lst: Second sequence to compare
name_0: sequence #0 name
name_1: sequence #1 name
num_outputs_0_skip_tokens: If > 0, specifies the number of initial
sequence #0 tokens & logprobs to discard
before comparison, i.e. all
of sequence #1 will be compared to
sequence #0 beginning at index
num_outputs_0_skip_tokens
warn_on_mismatch: Issue a warning if there is token-wise or text-wise
mismatch between the two sequences
always_check_logprobs: If true, check logprobs even when tokens match
"""
assert len(outputs_0_lst) == len(outputs_1_lst)
# Loop through responses to each prompt.
for prompt_idx, (outputs_0,
outputs_1) in enumerate(zip(outputs_0_lst,
outputs_1_lst)):
assert len(outputs_0) == len(outputs_1)
if len(outputs_0) == 3:
assert len(outputs_1) == 3
# Break out tokens, text & sample logprobs
# (prompt logprobs were not provided)
output_ids_0, output_str_0, logprobs_0 = outputs_0
output_ids_1, output_str_1, logprobs_1 = outputs_1
elif len(outputs_0) == 4:
assert len(outputs_1) == 4
# Break out tokens, text, sample logprobs & prompt logprobs
(
output_ids_0,
output_str_0,
logprobs_0,
prompt_logprobs_0,
) = outputs_0
(
output_ids_1,
output_str_1,
logprobs_1,
prompt_logprobs_1,
) = outputs_1
# Test prompt logprobs closeness
if (prompt_logprobs_0 is not None
and prompt_logprobs_1 is not None):
# Both sequences' prompt logprobs lists are not `None`
# (although individual list elements may be `None`);
# for each token's logprobs:
for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
zip(prompt_logprobs_0, prompt_logprobs_1)):
fail_msg = (
f"Prompt logprobs test:"
f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")
if logprobs_elem_0 is None:
# If the seq 0 token's logprobs are `None`,
# the seq 1 token's logprobs must be `None`
assert logprobs_elem_1 is None, fail_msg
else:
# If the seq 0 token's logprobs are not `None`,
# the seq 1 token's logprobs must not be `None`
assert logprobs_elem_1 is not None, fail_msg
# Logprobs check: top-k token choices must be the same
assert (set(logprobs_elem_0.keys()) == set(
logprobs_elem_1.keys())), fail_msg
else:
# Both sequence logprobs lists must be `None`
fail_msg = (f"Prompt logprobs test:"
f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
assert (prompt_logprobs_0 is None
and prompt_logprobs_1 is None), fail_msg
else:
raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
f"{len(outputs_0)} elements were provided: "
f"{outputs_0}")
if logprobs_0 is None:
logprobs_0 = [None] * len(output_ids_0)
if logprobs_1 is None:
logprobs_1 = [None] * len(output_ids_1)
# Skip specified number of initial sequence #0 tokens
# & logprobs, leaving output text as-is for simplicity
# (text mismatches may generate warnings but do not
# cause the test to fail.)
if num_outputs_0_skip_tokens < 0:
raise ValueError("num_outputs_0_skip_tokens must be non-negative")
output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]
# Loop through generated tokens.
for idx, (output_id_0,
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
is_tok_mismatch = output_id_0 != output_id_1
# If generated tokens don't match
# or it is desired to always check logprobs,
# then
if is_tok_mismatch or always_check_logprobs:
logprobs_elem_0 = logprobs_0[idx]
logprobs_elem_1 = logprobs_1[idx]
# Each predicted token must be in top N logprobs of the other
fail_msg = (
f"Test{prompt_idx}:"
f"\nMatched tokens:\t{output_ids_0[:idx]}"
f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")
assert logprobs_elem_0 is not None, fail_msg
assert logprobs_elem_1 is not None, fail_msg
assert output_id_0 in logprobs_elem_1, fail_msg
assert output_id_1 in logprobs_elem_0, fail_msg
if warn_on_mismatch and is_tok_mismatch:
with warnings.catch_warnings():
# This ensures that repeated warnings are shown
# in the output, not just the first occurrence
warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2)
# Break out since sequences will now diverge.
break
else:
if output_str_0 != output_str_1 and warn_on_mismatch:
# The token outputs exactly match,
# so the text outputs should exactly match as well
fail_msg = (f"Test{prompt_idx}:"
f"\n{name_0}:\t{output_str_0!r}"
f"\n{name_1}:\t{output_str_1!r}")
with warnings.catch_warnings():
# This ensures that repeated warnings are shown
# in the output, not just the first occurrence
warnings.simplefilter("always")
warnings.warn(fail_msg, stacklevel=2)
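# Illustrative usage (the variable names are examples): outputs returned by
# VllmRunner.generate_greedy_logprobs() for two configurations can be compared
# with
#   check_logprobs_close(outputs_0_lst=ref_outputs, outputs_1_lst=test_outputs,
#                        name_0="ref", name_1="test")
# where each element is a (token_ids, text, logprobs[, prompt_logprobs]) tuple.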
def qwen_prompt(questions: List[str]) -> List[str]:
placeholder = "<|image_pad|>"
return [("<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
f"<|im_start|>user\n<|vision_start|>{placeholder}<|vision_end|>"
f"{q}<|im_end|>\n<|im_start|>assistant\n") for q in questions]
# Map of prompt templates for different models.
PROMPT_TEMPLATES: dict[str, Callable] = {
"qwen2.5vl": qwen_prompt,
}


@@ -26,12 +26,11 @@ from unittest.mock import patch
from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams
from tests.conftest import VllmRunner
from tests.e2e.conftest import VllmRunner
@patch.dict(
os.environ, {
"VLLM_USE_V1": "1",
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
"TASK_QUEUE_ENABLE": "1",
"VLLM_ENABLE_FUSED_EXPERTS_ALLGATHER_EP": "1"
@@ -56,12 +55,10 @@ def test_generate_with_allgather():
vllm_model.generate(example_prompts, sampling_params)
@patch.dict(
os.environ, {
"VLLM_USE_V1": "1",
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
"TASK_QUEUE_ENABLE": "1"
})
@patch.dict(os.environ, {
"VLLM_WORKER_MULTIPROC_METHOD": "spawn",
"TASK_QUEUE_ENABLE": "1"
})
def test_generate_with_alltoall():
example_prompts = ["Hello, my name is"]
sampling_params = SamplingParams(max_tokens=100, temperature=0.0)
@@ -79,4 +76,4 @@ def test_generate_with_alltoall():
},
"expert_tensor_parallel_size": 1
}) as vllm_model:
vllm_model.generate(example_prompts, sampling_params)
vllm_model.generate(example_prompts, sampling_params)


@@ -1,7 +1,7 @@
import pytest
from modelscope import snapshot_download # type: ignore
from tests.conftest import VllmRunner
from tests.e2e.conftest import VllmRunner
from tests.e2e.singlecard.test_ilama_lora import (EXPECTED_LORA_OUTPUT,
MODEL_PATH, do_sample)


@@ -27,7 +27,7 @@ from modelscope import snapshot_download # type: ignore
from vllm import SamplingParams
from vllm.model_executor.models.registry import ModelRegistry
from tests.conftest import VllmRunner
from tests.e2e.conftest import VllmRunner
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"


@@ -16,7 +16,7 @@
#
import pytest
from tests.conftest import VllmRunner
from tests.e2e.conftest import VllmRunner
MODELS = [
"Qwen/Qwen3-0.6B",


@@ -2,12 +2,10 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Compare the with and without prefix caching on V1 scheduler or AscendScheduler."""
import os
import pytest
from tests.conftest import VllmRunner
from tests.model_utils import check_outputs_equal
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
MODELS = [
# for MHA
@@ -60,8 +58,6 @@ INPUT_PROMPTS = [
]
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="mtp is not supported on v1")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [50])
def test_prefix_cache_with_v1_scheduler(model: str, max_tokens: int) -> None:
@@ -89,8 +85,6 @@ def test_prefix_cache_with_v1_scheduler(model: str, max_tokens: int) -> None:
)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="mtp is not supported on v1")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [50])
def test_prefix_cache_with_ascend_scheduler(model: str,


@@ -22,9 +22,7 @@ Run `pytest tests/multicard/test_torchair_graph_mode.py`.
import os
from typing import Dict
import pytest
from tests.conftest import VllmRunner
from tests.e2e.conftest import VllmRunner
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
@@ -78,8 +76,6 @@ def _deepseek_torchair_test_fixture(
print(f"Generated text: {vllm_output[i][1]!r}")
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="torchair graph is not supported on v0")
def test_e2e_deepseekv3_with_torchair():
additional_config = {
"torchair_graph_config": {
@@ -89,8 +85,6 @@ def test_e2e_deepseekv3_with_torchair():
_deepseek_torchair_test_fixture(additional_config)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="torchair graph is not supported on v0")
def test_e2e_deepseekv3_with_torchair_ms_mla():
additional_config = {
"torchair_graph_config": {
@@ -150,8 +144,6 @@ def _pangu_torchair_test_fixture(
print(f"Generated text: {vllm_output[i][1]!r}")
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="torchair graph is not supported on v0")
def test_e2e_pangu_with_torchair():
additional_config = {
"torchair_graph_config": {


@@ -1,15 +1,11 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import gc
import os
import pytest
import torch
from vllm import LLM
if os.getenv("VLLM_USE_V1", "0") != "1":
pytest.skip("Test package requires V1", allow_module_level=True)
MODEL = "Qwen/Qwen2.5-0.5B-Instruct"
PROMPT = "Hello my name is Robert and I"


@@ -9,8 +9,8 @@ Run `pytest tests/e2e/singlecard/core/ascend_scheduler/test_chunk_prefill.py`.
"""
import pytest
from tests.conftest import VllmRunner
from tests.model_utils import check_outputs_equal
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
MODELS = [
"Qwen/Qwen3-0.6B-Base",


@@ -53,7 +53,6 @@ def model_name():
@pytest.mark.skipif(
True, reason="TODO: Enable me after test_mtp_correctness is fixed")
def test_mtp_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
@@ -62,33 +61,30 @@ def test_mtp_correctness(
Compare the outputs of a original LLM and a speculative LLM
should be the same when using mtp speculative decoding.
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
spec_llm = LLM(model=model_name,
trust_remote_code=True,
speculative_config={
"method": "deepseek_mtp",
"num_speculative_tokens": 1,
},
max_model_len=256,
enforce_eager=True)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
spec_llm = LLM(model=model_name,
trust_remote_code=True,
speculative_config={
"method": "deepseek_mtp",
"num_speculative_tokens": 1,
},
max_model_len=256,
enforce_eager=True)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm


@@ -60,7 +60,6 @@ def eagle3_model_name():
def test_ngram_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
@@ -70,44 +69,40 @@ def test_ngram_correctness(
should be the same when using ngram speculative decoding.
'''
pytest.skip("Not current support for the test.")
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
ref_llm = LLM(model=model_name, max_model_len=1024, enforce_eager=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
spec_llm = LLM(
model=model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
enforce_eager=True,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
spec_llm = LLM(
model=model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
enforce_eager=True,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
del spec_llm
# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.7 * len(ref_outputs))
del spec_llm
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eagle_correctness(
monkeypatch: pytest.MonkeyPatch,
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
@@ -119,43 +114,40 @@ def test_eagle_correctness(
'''
if not use_eagle3:
pytest.skip("Not current support for the test.")
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
ref_llm = LLM(model=model_name, max_model_len=2048, enforce_eager=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
spec_model_name = eagle3_model_name(
) if use_eagle3 else eagle_model_name()
spec_llm = LLM(
model=model_name,
trust_remote_code=True,
enable_chunked_prefill=True,
max_num_seqs=1,
max_num_batched_tokens=2048,
gpu_memory_utilization=0.6,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 2,
"max_model_len": 128,
},
max_model_len=128,
enforce_eager=True,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
spec_llm = LLM(
model=model_name,
trust_remote_code=True,
enable_chunked_prefill=True,
max_num_seqs=1,
max_num_batched_tokens=2048,
gpu_memory_utilization=0.6,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 2,
"max_model_len": 128,
},
max_model_len=128,
enforce_eager=True,
)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm


@@ -20,14 +20,12 @@ Compare the outputs of vLLM with and without aclgraph.
Run `pytest tests/compile/test_aclgraph.py`.
"""
import os
import pytest
import torch
from vllm import LLM, SamplingParams
from tests.conftest import VllmRunner
from tests.model_utils import check_outputs_equal
from tests.e2e.conftest import VllmRunner
from tests.e2e.model_utils import check_outputs_equal
MODELS = [
"Qwen/Qwen2.5-0.5B-Instruct",
@@ -36,37 +34,29 @@ MODELS = [
]
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="aclgraph only support on v1")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [32])
def test_models(
model: str,
max_tokens: int,
monkeypatch: pytest.MonkeyPatch,
) -> None:
with monkeypatch.context() as m:
prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]
prompts = [
"Hello, my name is", "The president of the United States is",
"The capital of France is", "The future of AI is"
]
# aclgraph only support on v1
m.setenv("VLLM_USE_V1", "1")
sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0.0)
# TODO: change to use vllmrunner when the registry of custom op is solved
# while running pytest
vllm_model = LLM(model)
vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
del vllm_model
torch.npu.empty_cache()
sampling_params = SamplingParams(max_tokens=max_tokens,
temperature=0.0)
# TODO: change to use vllmrunner when the registry of custom op is solved
# while running pytest
vllm_model = LLM(model)
vllm_aclgraph_outputs = vllm_model.generate(prompts, sampling_params)
del vllm_model
torch.npu.empty_cache()
vllm_model = LLM(model, enforce_eager=True)
vllm_eager_outputs = vllm_model.generate(prompts, sampling_params)
del vllm_model
torch.npu.empty_cache()
vllm_model = LLM(model, enforce_eager=True)
vllm_eager_outputs = vllm_model.generate(prompts, sampling_params)
del vllm_model
torch.npu.empty_cache()
vllm_aclgraph_outputs_list = []
for output in vllm_aclgraph_outputs:
@@ -86,12 +76,9 @@ def test_models(
)
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="aclgraph only support on v1")
def test_deepseek_raises_error(monkeypatch: pytest.MonkeyPatch) -> None:
with monkeypatch.context() as m:
m.setenv("VLLM_USE_MODELSCOPE", "True")
m.setenv("VLLM_USE_V1", "1")
with pytest.raises(NotImplementedError) as excinfo:
VllmRunner("deepseek-ai/DeepSeek-V2-Lite-Chat",
max_model_len=1024,


@@ -21,7 +21,7 @@ import torch
from vllm import LLM, SamplingParams
from vllm.utils import GiB_bytes
from tests.utils import fork_new_process_for_each_test
from tests.e2e.utils import fork_new_process_for_each_test
from vllm_ascend.device_allocator.camem import CaMemAllocator


@@ -20,8 +20,6 @@ Compare the outputs of vLLM with and without aclgraph.
Run `pytest tests/compile/test_aclgraph.py`.
"""
import os
import pytest
import torch
from vllm import LLM, SamplingParams
@@ -29,8 +27,6 @@ from vllm import LLM, SamplingParams
MODELS = ["deepseek-ai/DeepSeek-V2-Lite"]
@pytest.mark.skipif(os.getenv("VLLM_USE_V1") == "0",
reason="new chunked only support on v1")
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("max_tokens", [1])
def test_models(
@@ -39,36 +35,33 @@ def test_models(
monkeypatch: pytest.MonkeyPatch,
) -> None:
return
with monkeypatch.context() as m:
prompts = "The president of the United States is"
m.setenv("VLLM_USE_V1", "1")
prompts = "The president of the United States is"
sampling_params = SamplingParams(
max_tokens=max_tokens,
temperature=0.0,
)
sampling_params = SamplingParams(
max_tokens=max_tokens,
temperature=0.0,
)
vllm_model = LLM(model,
long_prefill_token_threshold=4,
enforce_eager=True)
output_chunked = vllm_model.generate(prompts, sampling_params)
logprobs_chunked = output_chunked.outputs[0].logprobs
del vllm_model
torch.npu.empty_cache()
vllm_model = LLM(model, long_prefill_token_threshold=4, enforce_eager=True)
output_chunked = vllm_model.generate(prompts, sampling_params)
logprobs_chunked = output_chunked.outputs[0].logprobs
del vllm_model
torch.npu.empty_cache()
vllm_model = LLM(model,
enforce_eager=True,
additional_config={
'ascend_scheduler_config': {
'enabled': True
},
})
output = vllm_model.generate(prompts, sampling_params)
logprobs = output.outputs[0].logprobs
del vllm_model
torch.npu.empty_cache()
vllm_model = LLM(model,
enforce_eager=True,
additional_config={
'ascend_scheduler_config': {
'enabled': True
},
})
output = vllm_model.generate(prompts, sampling_params)
logprobs = output.outputs[0].logprobs
del vllm_model
torch.npu.empty_cache()
logprobs_similarity = torch.cosine_similarity(
logprobs_chunked.flatten(), logprobs.flatten(), dim=0)
assert logprobs_similarity > 0.95
logprobs_similarity = torch.cosine_similarity(logprobs_chunked.flatten(),
logprobs.flatten(),
dim=0)
assert logprobs_similarity > 0.95


@@ -21,8 +21,8 @@ from typing import Optional
from modelscope import snapshot_download # type: ignore[import-untyped]
from tests.conftest import HfRunner
from tests.utils import check_embeddings_close, matryoshka_fy
from tests.e2e.conftest import HfRunner
from tests.e2e.utils import check_embeddings_close, matryoshka_fy
def run_embedding_correctness_test(


@@ -18,14 +18,14 @@
#
import json
import os
import re
import jsonschema
import pytest
import regex as re
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from tests.conftest import VllmRunner
from tests.e2e.conftest import VllmRunner
os.environ["PYTORCH_NPU_ALLOC_CONF"] = "max_split_size_mb:256"
MODEL_NAME = "Qwen/Qwen2.5-0.5B-Instruct"
@@ -85,11 +85,7 @@ def sample_json_schema():
def check_backend(guided_decoding_backend: str):
if guided_decoding_backend not in GuidedDecodingBackendV0 and os.getenv(
"VLLM_USE_V1") == "0":
pytest.skip(f"{guided_decoding_backend} does not support v0, skip it.")
if guided_decoding_backend not in GuidedDecodingBackendV1 and os.getenv(
"VLLM_USE_V1") == "1":
if guided_decoding_backend not in GuidedDecodingBackendV1:
pytest.skip(f"{guided_decoding_backend} does not support v1, skip it.")


@@ -3,7 +3,7 @@ import vllm
from modelscope import snapshot_download # type: ignore
from vllm.lora.request import LoRARequest
from tests.conftest import VllmRunner
from tests.e2e.conftest import VllmRunner
MODEL_PATH = "vllm-ascend/ilama-3.2-1B"


@@ -30,7 +30,7 @@ from vllm import SamplingParams
from vllm.assets.image import ImageAsset
import vllm_ascend # noqa: F401
from tests.conftest import VllmRunner
from tests.e2e.conftest import VllmRunner
MODELS = [
"Qwen/Qwen2.5-0.5B-Instruct",

tests/e2e/utils.py (new file, 106 lines added)

@@ -0,0 +1,106 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/tests/utils.py
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import functools
import os
import signal
from collections.abc import Sequence
from typing import Callable
import torch
import torch.nn.functional as F
from typing_extensions import ParamSpec
_P = ParamSpec("_P")
def fork_new_process_for_each_test(
f: Callable[_P, None]) -> Callable[_P, None]:
"""Decorator to fork a new process for each test function.
See https://github.com/vllm-project/vllm/issues/7053 for more details.
"""
@functools.wraps(f)
def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None:
# Make the process the leader of its own process group
# to avoid sending SIGTERM to the parent process
os.setpgrp()
from _pytest.outcomes import Skipped
pid = os.fork()
print(f"Fork a new process to run a test {pid}")
if pid == 0:
try:
f(*args, **kwargs)
except Skipped as e:
# convert Skipped to exit code 0
print(str(e))
os._exit(0)
except Exception:
import traceback
traceback.print_exc()
os._exit(1)
else:
os._exit(0)
else:
pgid = os.getpgid(pid)
_pid, _exitcode = os.waitpid(pid, 0)
# ignore SIGTERM signal itself
old_signal_handler = signal.signal(signal.SIGTERM, signal.SIG_IGN)
# kill all child processes
os.killpg(pgid, signal.SIGTERM)
# restore the signal handler
signal.signal(signal.SIGTERM, old_signal_handler)
assert _exitcode == 0, (f"function {f} failed when called with"
f" args {args} and kwargs {kwargs}")
return wrapper
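# Illustrative usage (the test name is only an example): decorate a test so it
# runs in a forked child process, keeping NPU/distributed state isolated from
# the pytest parent process:
#   @fork_new_process_for_each_test
#   def test_sleep_mode():
#       ...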
def matryoshka_fy(tensor: torch.Tensor, dimensions: int):
tensor = torch.tensor(tensor)
tensor = tensor[..., :dimensions]
tensor = F.normalize(tensor, p=2, dim=1)
return tensor
def check_embeddings_close(
*,
embeddings_0_lst: Sequence[list[float]],
embeddings_1_lst: Sequence[list[float]],
name_0: str,
name_1: str,
tol: float = 1e-3,
) -> None:
assert len(embeddings_0_lst) == len(embeddings_1_lst)
for prompt_idx, (embeddings_0, embeddings_1) in enumerate(
zip(embeddings_0_lst, embeddings_1_lst)):
assert len(embeddings_0) == len(embeddings_1), (
f"Length mismatch: {len(embeddings_0)} vs. {len(embeddings_1)}")
sim = F.cosine_similarity(torch.tensor(embeddings_0),
torch.tensor(embeddings_1),
dim=0)
fail_msg = (f"Test{prompt_idx}:"
f"\nCosine similarity: \t{sim:.4f}"
f"\n{name_0}:\t{embeddings_0[:16]!r}"
f"\n{name_1}:\t{embeddings_1[:16]!r}")
assert sim >= 1 - tol, fail_msg
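# Illustrative usage (dimension, tolerance, and names are examples): truncate
# embeddings to a matryoshka dimension, then compare vLLM and HF outputs:
#   vllm_trunc = matryoshka_fy(vllm_embeddings, dimensions=256)
#   hf_trunc = matryoshka_fy(hf_embeddings, dimensions=256)
#   check_embeddings_close(embeddings_0_lst=vllm_trunc,
#                          embeddings_1_lst=hf_trunc,
#                          name_0="vllm", name_1="hf", tol=1e-2)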