[Core] Init vllm-ascend (#3)
### What this PR does / why we need it?
vLLM Ascend plugin (vllm-ascend) is a backend plugin for running vLLM on
the Ascend NPU.
This plugin is the recommended approach for supporting the Ascend
backend within the vLLM community. It adheres to the principles outlined
in the [RFC]: Hardware pluggable, providing a hardware-pluggable
interface that decouples the integration of the Ascend NPU with vLLM.
This patch also include changes to make CI work and use cache speed up
e2e test, including:
1. Change push (post merge ci) and pull_request (pr ci) trigger branch
to main
2. Make mypy work by ignore base_communicator and clear unused deps
3. Several improvements for vllm_ascend_test:
- use cache (pip, ms, hf) speed up e2e test (25mins --> 5mins)
- switch `git clone` command to `action/checkout` to speedup checkout
and
- Enable sv for pytest for better info dump
- Remove network host to resole `docker: conflicting ontions: cannot
attach both user-defined and non-user-definednetwork-modes`, which is a
problem on docker 1.45 but not on 1.39.
4. Adapt MLA decode optimizations:
cabaf4eff3
### Does this PR introduce _any_ user-facing change?
Yes, init the PR.
### How was this patch tested?
- This is the first PR to make ascend NPU work on vLLM. All code is
tested on ascend with vLLM V0 Engine.
- CI passed
---------
Signed-off-by: wangxiyuan <wangxiyuan1007@gmail.com>
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
Co-authored-by: wangxiyuan <wangxiyuan1007@gmail.com>
Co-authored-by: MengqingCao <cmq0113@163.com>
Co-authored-by: wangshuai09 <391746016@qq.com>
Co-authored-by: Shanshan Shen <467638484@qq.com>
Co-authored-by: wangli <wangli858794774@gmail.com>
This commit is contained in:
331
tests/conftest.py
Normal file
331
tests/conftest.py
Normal file
@@ -0,0 +1,331 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
from typing import List, Optional, Tuple, TypeVar, Union
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from vllm import LLM, SamplingParams
|
||||
from vllm.config import TaskOption
|
||||
from vllm.distributed import cleanup_dist_env_and_memory
|
||||
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
|
||||
from vllm.logger import init_logger
|
||||
from vllm.outputs import RequestOutput
|
||||
from vllm.sampling_params import BeamSearchParams
|
||||
from vllm.utils import is_list_of
|
||||
|
||||
from tests.model_utils import (TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
_M = TypeVar("_M")
|
||||
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]
|
||||
|
||||
PromptImageInput = _PromptMultiModalInput[Image.Image]
|
||||
PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
|
||||
PromptVideoInput = _PromptMultiModalInput[np.ndarray]
|
||||
|
||||
|
||||
class VllmRunner:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
task: TaskOption = "auto",
|
||||
tokenizer_name: Optional[str] = None,
|
||||
tokenizer_mode: str = "auto",
|
||||
# Use smaller max model length, otherwise bigger model cannot run due
|
||||
# to kv cache size limit.
|
||||
max_model_len: int = 1024,
|
||||
dtype: str = "half",
|
||||
disable_log_stats: bool = True,
|
||||
tensor_parallel_size: int = 1,
|
||||
block_size: int = 16,
|
||||
enable_chunked_prefill: bool = False,
|
||||
swap_space: int = 4,
|
||||
enforce_eager: Optional[bool] = False,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
self.model = LLM(
|
||||
model=model_name,
|
||||
task=task,
|
||||
tokenizer=tokenizer_name,
|
||||
tokenizer_mode=tokenizer_mode,
|
||||
trust_remote_code=True,
|
||||
dtype=dtype,
|
||||
swap_space=swap_space,
|
||||
enforce_eager=enforce_eager,
|
||||
disable_log_stats=disable_log_stats,
|
||||
tensor_parallel_size=tensor_parallel_size,
|
||||
max_model_len=max_model_len,
|
||||
block_size=block_size,
|
||||
enable_chunked_prefill=enable_chunked_prefill,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def get_inputs(
|
||||
self,
|
||||
prompts: List[str],
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
) -> List[TextPrompt]:
|
||||
if images is not None:
|
||||
assert len(prompts) == len(images)
|
||||
|
||||
if videos is not None:
|
||||
assert len(prompts) == len(videos)
|
||||
|
||||
if audios is not None:
|
||||
assert len(prompts) == len(audios)
|
||||
|
||||
inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
|
||||
if images is not None:
|
||||
for i, image in enumerate(images):
|
||||
if image is not None:
|
||||
inputs[i]["multi_modal_data"] = {"image": image}
|
||||
|
||||
if videos is not None:
|
||||
for i, video in enumerate(videos):
|
||||
if video is not None:
|
||||
inputs[i]["multi_modal_data"] = {"video": video}
|
||||
|
||||
if audios is not None:
|
||||
for i, audio in enumerate(audios):
|
||||
if audio is not None:
|
||||
inputs[i]["multi_modal_data"] = {"audio": audio}
|
||||
|
||||
return inputs
|
||||
|
||||
def generate(
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: SamplingParams,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
||||
inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
|
||||
req_outputs = self.model.generate(inputs,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
outputs: List[Tuple[List[List[int]], List[str]]] = []
|
||||
for req_output in req_outputs:
|
||||
prompt_str = req_output.prompt
|
||||
prompt_ids = req_output.prompt_token_ids
|
||||
req_sample_output_ids: List[List[int]] = []
|
||||
req_sample_output_strs: List[str] = []
|
||||
for sample in req_output.outputs:
|
||||
output_str = sample.text
|
||||
output_ids = list(sample.token_ids)
|
||||
req_sample_output_ids.append(prompt_ids + output_ids)
|
||||
req_sample_output_strs.append(prompt_str + output_str)
|
||||
outputs.append((req_sample_output_ids, req_sample_output_strs))
|
||||
return outputs
|
||||
|
||||
@staticmethod
|
||||
def _final_steps_generate_w_logprobs(
|
||||
req_outputs: List[RequestOutput],
|
||||
) -> List[TokensTextLogprobsPromptLogprobs]:
|
||||
outputs: List[TokensTextLogprobsPromptLogprobs] = []
|
||||
for req_output in req_outputs:
|
||||
assert len(req_output.outputs) > 0
|
||||
for sample in req_output.outputs:
|
||||
output_str = sample.text
|
||||
output_ids = list(sample.token_ids)
|
||||
output_logprobs = sample.logprobs
|
||||
outputs.append((output_ids, output_str, output_logprobs,
|
||||
req_output.prompt_logprobs))
|
||||
return outputs
|
||||
|
||||
def generate_w_logprobs(
|
||||
self,
|
||||
prompts: List[str],
|
||||
sampling_params: SamplingParams,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
) -> Union[List[TokensTextLogprobs],
|
||||
List[TokensTextLogprobsPromptLogprobs]]:
|
||||
inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
|
||||
req_outputs = self.model.generate(inputs,
|
||||
sampling_params=sampling_params)
|
||||
|
||||
toks_str_logsprobs_prompt_logprobs = (
|
||||
self._final_steps_generate_w_logprobs(req_outputs))
|
||||
# Omit prompt logprobs if not required by sampling params
|
||||
return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
|
||||
if sampling_params.prompt_logprobs is None else
|
||||
toks_str_logsprobs_prompt_logprobs)
|
||||
|
||||
def generate_encoder_decoder_w_logprobs(
|
||||
self,
|
||||
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
|
||||
sampling_params: SamplingParams,
|
||||
) -> Union[List[TokensTextLogprobs],
|
||||
List[TokensTextLogprobsPromptLogprobs]]:
|
||||
'''
|
||||
Logprobs generation for vLLM encoder/decoder models
|
||||
'''
|
||||
|
||||
assert sampling_params.logprobs is not None
|
||||
req_outputs = self.model.generate(encoder_decoder_prompts,
|
||||
sampling_params=sampling_params)
|
||||
toks_str_logsprobs_prompt_logprobs = (
|
||||
self._final_steps_generate_w_logprobs(req_outputs))
|
||||
# Omit prompt logprobs if not required by sampling params
|
||||
return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
|
||||
if sampling_params.prompt_logprobs is None else
|
||||
toks_str_logsprobs_prompt_logprobs)
|
||||
|
||||
def generate_greedy(
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
) -> List[Tuple[List[int], str]]:
|
||||
greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
|
||||
outputs = self.generate(prompts,
|
||||
greedy_params,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
return [(output_ids[0], output_str[0])
|
||||
for output_ids, output_str in outputs]
|
||||
|
||||
def generate_greedy_logprobs(
|
||||
self,
|
||||
prompts: List[str],
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
num_prompt_logprobs: Optional[int] = None,
|
||||
images: Optional[PromptImageInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
stop_token_ids: Optional[List[int]] = None,
|
||||
stop: Optional[List[str]] = None,
|
||||
) -> Union[List[TokensTextLogprobs],
|
||||
List[TokensTextLogprobsPromptLogprobs]]:
|
||||
greedy_logprobs_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
logprobs=num_logprobs,
|
||||
prompt_logprobs=num_prompt_logprobs,
|
||||
stop_token_ids=stop_token_ids,
|
||||
stop=stop)
|
||||
|
||||
return self.generate_w_logprobs(prompts,
|
||||
greedy_logprobs_params,
|
||||
images=images,
|
||||
audios=audios,
|
||||
videos=videos)
|
||||
|
||||
def generate_encoder_decoder_greedy_logprobs(
|
||||
self,
|
||||
encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
|
||||
max_tokens: int,
|
||||
num_logprobs: int,
|
||||
num_prompt_logprobs: Optional[int] = None,
|
||||
) -> Union[List[TokensTextLogprobs],
|
||||
List[TokensTextLogprobsPromptLogprobs]]:
|
||||
greedy_logprobs_params = SamplingParams(
|
||||
temperature=0.0,
|
||||
max_tokens=max_tokens,
|
||||
logprobs=num_logprobs,
|
||||
prompt_logprobs=(num_prompt_logprobs),
|
||||
)
|
||||
'''
|
||||
Greedy logprobs generation for vLLM encoder/decoder models
|
||||
'''
|
||||
|
||||
return self.generate_encoder_decoder_w_logprobs(
|
||||
encoder_decoder_prompts, greedy_logprobs_params)
|
||||
|
||||
def generate_beam_search(
|
||||
self,
|
||||
prompts: Union[List[str], List[List[int]]],
|
||||
beam_width: int,
|
||||
max_tokens: int,
|
||||
) -> List[Tuple[List[List[int]], List[str]]]:
|
||||
if is_list_of(prompts, str, check="all"):
|
||||
prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
|
||||
else:
|
||||
prompts = [
|
||||
TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
|
||||
]
|
||||
outputs = self.model.beam_search(
|
||||
prompts,
|
||||
BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
|
||||
returned_outputs = []
|
||||
for output in outputs:
|
||||
token_ids = [x.tokens for x in output.sequences]
|
||||
texts = [x.text for x in output.sequences]
|
||||
returned_outputs.append((token_ids, texts))
|
||||
return returned_outputs
|
||||
|
||||
def classify(self, prompts: List[str]) -> List[List[float]]:
|
||||
req_outputs = self.model.classify(prompts)
|
||||
return [req_output.outputs.probs for req_output in req_outputs]
|
||||
|
||||
def encode(
|
||||
self,
|
||||
prompts: List[str],
|
||||
images: Optional[PromptImageInput] = None,
|
||||
videos: Optional[PromptVideoInput] = None,
|
||||
audios: Optional[PromptAudioInput] = None,
|
||||
) -> List[List[float]]:
|
||||
inputs = self.get_inputs(prompts,
|
||||
images=images,
|
||||
videos=videos,
|
||||
audios=audios)
|
||||
|
||||
req_outputs = self.model.embed(inputs)
|
||||
return [req_output.outputs.embedding for req_output in req_outputs]
|
||||
|
||||
def score(
|
||||
self,
|
||||
text_1: Union[str, List[str]],
|
||||
text_2: Union[str, List[str]],
|
||||
) -> List[float]:
|
||||
req_outputs = self.model.score(text_1, text_2)
|
||||
return [req_output.outputs.score for req_output in req_outputs]
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_value, traceback):
|
||||
del self.model
|
||||
cleanup_dist_env_and_memory()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def vllm_runner():
|
||||
return VllmRunner
|
||||
303
tests/model_utils.py
Normal file
303
tests/model_utils.py
Normal file
@@ -0,0 +1,303 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm-project/vllm/blob/main/tests/models/utils.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
|
||||
import warnings
|
||||
from typing import Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
from vllm.config import ModelConfig, TaskOption
|
||||
from vllm.inputs import InputContext
|
||||
from vllm.sequence import Logprob, PromptLogprobs, SampleLogprobs
|
||||
|
||||
TokensText = Tuple[List[int], str]
|
||||
|
||||
|
||||
def check_outputs_equal(
|
||||
*,
|
||||
outputs_0_lst: Sequence[TokensText],
|
||||
outputs_1_lst: Sequence[TokensText],
|
||||
name_0: str,
|
||||
name_1: str,
|
||||
):
|
||||
"""
|
||||
Compare the two sequences generated by different models,
|
||||
which should be equal.
|
||||
"""
|
||||
assert len(outputs_0_lst) == len(outputs_1_lst)
|
||||
|
||||
for prompt_idx, (outputs_0,
|
||||
outputs_1) in enumerate(zip(outputs_0_lst,
|
||||
outputs_1_lst)):
|
||||
output_ids_0, output_str_0 = outputs_0
|
||||
output_ids_1, output_str_1 = outputs_1
|
||||
|
||||
# The text and token outputs should exactly match
|
||||
fail_msg = (f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
|
||||
assert output_str_0 == output_str_1, fail_msg
|
||||
assert output_ids_0 == output_ids_1, fail_msg
|
||||
|
||||
|
||||
# Representation of generated sequence as a tuple of
|
||||
# * Token ID list
|
||||
# * String
|
||||
# * List of top sample logprobs for each sampled token
|
||||
#
|
||||
# Assumes prompt logprobs were not requested.
|
||||
TokensTextLogprobs = Tuple[List[int], str, Optional[Union[List[Dict[int,
|
||||
float]],
|
||||
SampleLogprobs]]]
|
||||
|
||||
# Allow for tokens to be represented as str's rather than IDs;
|
||||
# tuple of
|
||||
# * Token string representations list
|
||||
# * String
|
||||
# * Optional list of top sample logprobs for each sampled token
|
||||
#
|
||||
# Assumes prompt logprobs were not requested.
|
||||
TextTextLogprobs = Tuple[List[str], str, Optional[Union[List[Dict[str, float]],
|
||||
List[Dict[str,
|
||||
Logprob]]]]]
|
||||
|
||||
# Representation of generated sequence as a tuple of
|
||||
# * Token ID list
|
||||
# * String
|
||||
# * Optional list of top sample logprobs for each sampled token
|
||||
# * Optional list of top prompt logprobs for each prompt token
|
||||
#
|
||||
# Allows prompt logprobs to be requested.
|
||||
TokensTextLogprobsPromptLogprobs = Tuple[
|
||||
List[int], str, Optional[Union[List[Dict[int, float]], SampleLogprobs]],
|
||||
Optional[Union[List[Optional[Dict[int, float]]], PromptLogprobs]]]
|
||||
|
||||
|
||||
def check_logprobs_close(
|
||||
*,
|
||||
outputs_0_lst: Sequence[Union[TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs,
|
||||
TextTextLogprobs]],
|
||||
outputs_1_lst: Sequence[Union[TokensTextLogprobs,
|
||||
TokensTextLogprobsPromptLogprobs,
|
||||
TextTextLogprobs]],
|
||||
name_0: str,
|
||||
name_1: str,
|
||||
num_outputs_0_skip_tokens: int = 0,
|
||||
warn_on_mismatch: bool = True,
|
||||
always_check_logprobs: bool = False,
|
||||
) -> None:
|
||||
"""Compare the logprobs of two sequences generated by different models,
|
||||
which should be similar but not necessarily equal.
|
||||
|
||||
How sample logprobs are compared:
|
||||
* `always_check_logprobs == True`: set of highest-logprob token ids
|
||||
must match between seq0 and seq1 at all sampled token offsets
|
||||
* `always_check_logprobs == False`: highest-logprob token ids are
|
||||
only compared at sampled token offsets for which generated token
|
||||
ids don't match
|
||||
|
||||
Prompt logprobs must be provided either for both input sequences, or
|
||||
for neither. If prompt logprobs are provided, then highest-logprob
|
||||
prompt token ids must match between seq0 and seq1 at all prompt token
|
||||
offsets.
|
||||
|
||||
Args:
|
||||
outputs_0_lst: First sequence to compare
|
||||
outputs_0_lst: Second sequence to compare
|
||||
name_0: sequence #0 name
|
||||
name_1: sequence #1 name
|
||||
num_outputs_0_skip_tokens: If > 0, specifies the number of initial
|
||||
sequence #0 tokens & logprobs to discard
|
||||
before comparison, i.e. all
|
||||
of sequence #1 will be compared to
|
||||
sequence #0 beginning at index
|
||||
num_outputs_0_skip_tokens
|
||||
warn_on_mismatch: Issue a warning if there is token-wise or text-wise
|
||||
mismatch between the two sequences
|
||||
always_check_logprobs: If true, check logprobs even when tokens match
|
||||
"""
|
||||
assert len(outputs_0_lst) == len(outputs_1_lst)
|
||||
|
||||
# Loop through responses to each prompt.
|
||||
for prompt_idx, (outputs_0,
|
||||
outputs_1) in enumerate(zip(outputs_0_lst,
|
||||
outputs_1_lst)):
|
||||
assert len(outputs_0) == len(outputs_1)
|
||||
if len(outputs_0) == 3:
|
||||
assert len(outputs_1) == 3
|
||||
# Break out tokens, text & sample logprobs
|
||||
# (prompt logprobs were not provided)
|
||||
output_ids_0, output_str_0, logprobs_0 = outputs_0
|
||||
output_ids_1, output_str_1, logprobs_1 = outputs_1
|
||||
elif len(outputs_0) == 4:
|
||||
assert len(outputs_1) == 4
|
||||
# Break out tokens, text, sample logprobs & prompt logprobs
|
||||
(
|
||||
output_ids_0,
|
||||
output_str_0,
|
||||
logprobs_0,
|
||||
prompt_logprobs_0,
|
||||
) = outputs_0
|
||||
(
|
||||
output_ids_1,
|
||||
output_str_1,
|
||||
logprobs_1,
|
||||
prompt_logprobs_1,
|
||||
) = outputs_1
|
||||
|
||||
# Test prompt logprobs closeness
|
||||
if (prompt_logprobs_0 is not None
|
||||
and prompt_logprobs_1 is not None):
|
||||
# Both sequences' prompt logprobs lists are not `None``
|
||||
# (although individual list elements may be `None`);
|
||||
# for each token's logprobs:
|
||||
for idx, (logprobs_elem_0, logprobs_elem_1) in enumerate(
|
||||
zip(prompt_logprobs_0, prompt_logprobs_1)):
|
||||
fail_msg = (
|
||||
f"Prompt logprobs test:"
|
||||
f"\n{name_0}:\tPrompt index {idx}\t{logprobs_elem_0}"
|
||||
f"\n{name_1}:\tPrompt index {idx}\t{logprobs_elem_1}")
|
||||
|
||||
if logprobs_elem_0 is None:
|
||||
# If the seq 0 token's logprobs are `None`,
|
||||
# the seq 1 token's logprobs must be `None`
|
||||
assert logprobs_elem_1 is None, fail_msg
|
||||
else:
|
||||
# If the seq 0 token's logprobs are not `None`,
|
||||
# the seq 1 token's logprobs must not be `None`
|
||||
assert logprobs_elem_1 is not None, fail_msg
|
||||
# Logprobs check: top-k token choices must be the same
|
||||
assert (set(logprobs_elem_0.keys()) == set(
|
||||
logprobs_elem_1.keys())), fail_msg
|
||||
else:
|
||||
# Both sequence logprobs lists must be `None`
|
||||
fail_msg = (f"Prompt logprobs test:"
|
||||
f"\n{name_0}:\tlogprobs\t{prompt_logprobs_0}"
|
||||
f"\n{name_1}:\tlogprobs\t{prompt_logprobs_1}")
|
||||
|
||||
assert (prompt_logprobs_0 is None
|
||||
and prompt_logprobs_1 is None), fail_msg
|
||||
else:
|
||||
raise ValueError(f"Outputs tuple must have 3 or 4 elements but "
|
||||
f"{len(outputs_0)} elements were provided: "
|
||||
f"{outputs_0}")
|
||||
|
||||
if logprobs_0 is None:
|
||||
logprobs_0 = [None] * len(output_ids_0)
|
||||
if logprobs_1 is None:
|
||||
logprobs_1 = [None] * len(output_ids_1)
|
||||
|
||||
# Skip specified number of initial sequence #0 tokens
|
||||
# & logprobs, leaving output text as-is for simplicity
|
||||
# (text mismatches may generate warnings but do not
|
||||
# cause the test to fail.)
|
||||
if num_outputs_0_skip_tokens < 0:
|
||||
raise ValueError("num_outputs_0_skip_tokens must be non-negative")
|
||||
output_ids_0 = output_ids_0[num_outputs_0_skip_tokens:]
|
||||
logprobs_0 = logprobs_0[num_outputs_0_skip_tokens:]
|
||||
|
||||
# Loop through generated tokens.
|
||||
for idx, (output_id_0,
|
||||
output_id_1) in enumerate(zip(output_ids_0, output_ids_1)):
|
||||
|
||||
is_tok_mismatch = output_id_0 != output_id_1
|
||||
|
||||
# If generated tokens don't match
|
||||
# or it is desired to always check logprobs,
|
||||
# then
|
||||
if is_tok_mismatch or always_check_logprobs:
|
||||
logprobs_elem_0 = logprobs_0[idx]
|
||||
logprobs_elem_1 = logprobs_1[idx]
|
||||
|
||||
# Each predicted token must be in top N logprobs of the other
|
||||
fail_msg = (
|
||||
f"Test{prompt_idx}:"
|
||||
f"\nMatched tokens:\t{output_ids_0[:idx]}"
|
||||
f"\n{name_0}:\t{output_str_0!r}\t{logprobs_elem_0}"
|
||||
f"\n{name_1}:\t{output_str_1!r}\t{logprobs_elem_1}")
|
||||
|
||||
assert logprobs_elem_0 is not None, fail_msg
|
||||
assert logprobs_elem_1 is not None, fail_msg
|
||||
assert output_id_0 in logprobs_elem_1, fail_msg
|
||||
assert output_id_1 in logprobs_elem_0, fail_msg
|
||||
|
||||
if warn_on_mismatch and is_tok_mismatch:
|
||||
with warnings.catch_warnings():
|
||||
# This ensures that repeated warnings are shown
|
||||
# in the output, not just the first occurrence
|
||||
warnings.simplefilter("always")
|
||||
|
||||
warnings.warn(fail_msg, stacklevel=2)
|
||||
|
||||
# Break out since sequences will now diverge.
|
||||
break
|
||||
else:
|
||||
if output_str_0 != output_str_1 and warn_on_mismatch:
|
||||
# The token outputs exactly match,
|
||||
# so the text outputs should exactly match as well
|
||||
fail_msg = (f"Test{prompt_idx}:"
|
||||
f"\n{name_0}:\t{output_str_0!r}"
|
||||
f"\n{name_1}:\t{output_str_1!r}")
|
||||
|
||||
with warnings.catch_warnings():
|
||||
# This ensures that repeated warnings are shown
|
||||
# in the output, not just the first occurrence
|
||||
warnings.simplefilter("always")
|
||||
|
||||
warnings.warn(fail_msg, stacklevel=2)
|
||||
|
||||
|
||||
def build_model_context(model_name: str,
|
||||
task: TaskOption = "auto",
|
||||
tokenizer_name: Optional[str] = None,
|
||||
trust_remote_code: bool = False,
|
||||
dtype: Optional[Union[str, torch.dtype]] = None,
|
||||
mm_processor_kwargs: Optional[Dict] = None,
|
||||
limit_mm_per_prompt: Optional[Dict] = None):
|
||||
"""Creates an InputContext for a given model.
|
||||
|
||||
Args:
|
||||
model_name: Name of the model being considered.
|
||||
tokenizer_name: Name of the tokenizer being considered.
|
||||
trust_remote_code: Whether or not to allow loading remote code.
|
||||
mm_processor_kwargs: optional processor kwargs for to be leveraged
|
||||
in the input processor, mapper, dummy data creation, etc.
|
||||
limit_mm_per_prompt: Multimodal limits.
|
||||
|
||||
Returns:
|
||||
InputContext for the model being considered.
|
||||
"""
|
||||
if tokenizer_name is None:
|
||||
tokenizer_name = model_name
|
||||
if dtype is None:
|
||||
dtype = "half"
|
||||
|
||||
model_config = ModelConfig(
|
||||
model_name,
|
||||
task=task,
|
||||
tokenizer=tokenizer_name,
|
||||
tokenizer_mode="auto",
|
||||
trust_remote_code=trust_remote_code,
|
||||
dtype=dtype,
|
||||
seed=0,
|
||||
mm_processor_kwargs=mm_processor_kwargs,
|
||||
limit_mm_per_prompt=limit_mm_per_prompt,
|
||||
)
|
||||
return InputContext(model_config)
|
||||
61
tests/test_offline_inference.py
Normal file
61
tests/test_offline_inference.py
Normal file
@@ -0,0 +1,61 @@
|
||||
#
|
||||
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
|
||||
# This file is a part of the vllm-ascend project.
|
||||
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
|
||||
# Copyright 2023 The vLLM team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
"""Compare the short outputs of HF and vLLM when using greedy sampling.
|
||||
|
||||
Run `pytest tests/test_offline_inference.py`.
|
||||
"""
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import vllm # noqa: F401
|
||||
from conftest import VllmRunner
|
||||
|
||||
import vllm_ascend # noqa: F401
|
||||
|
||||
MODELS = [
|
||||
"Qwen/Qwen2.5-0.5B-Instruct",
|
||||
]
|
||||
os.environ["VLLM_USE_MODELSCOPE"] = "True"
|
||||
|
||||
TARGET_TEST_SUITE = os.environ.get("TARGET_TEST_SUITE", "L4")
|
||||
|
||||
|
||||
@pytest.mark.parametrize("model", MODELS)
|
||||
@pytest.mark.parametrize("dtype", ["half", "float16"])
|
||||
@pytest.mark.parametrize("max_tokens", [5])
|
||||
def test_models(
|
||||
model: str,
|
||||
dtype: str,
|
||||
max_tokens: int,
|
||||
) -> None:
|
||||
os.environ["VLLM_ATTENTION_BACKEND"] = "ASCEND"
|
||||
|
||||
# 5042 tokens for gemma2
|
||||
# gemma2 has alternating sliding window size of 4096
|
||||
# we need a prompt with more than 4096 tokens to test the sliding window
|
||||
prompt = "The following numbers of the sequence " + ", ".join(
|
||||
str(i) for i in range(1024)) + " are:"
|
||||
example_prompts = [prompt]
|
||||
|
||||
with VllmRunner(model,
|
||||
max_model_len=8192,
|
||||
dtype=dtype,
|
||||
enforce_eager=False,
|
||||
gpu_memory_utilization=0.7) as vllm_model:
|
||||
vllm_model.generate_greedy(example_prompts, max_tokens)
|
||||
Reference in New Issue
Block a user