#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm-project/vllm/blob/main/tests/conftest.py
#
import contextlib
import gc
from typing import List, Optional, Tuple, TypeVar, Union

import numpy as np
import pytest
import torch
from huggingface_hub import snapshot_download
from PIL import Image
from vllm import LLM, SamplingParams
from vllm.config import TaskOption
from vllm.inputs import ExplicitEncoderDecoderPrompt, TextPrompt, TokensPrompt
from vllm.outputs import RequestOutput
from vllm.sampling_params import BeamSearchParams
from vllm.utils import is_list_of

from tests.model_utils import (PROMPT_TEMPLATES, TokensTextLogprobs,
                               TokensTextLogprobsPromptLogprobs)

# TODO: remove this part after the patch is merged into vllm. If we do not
# explicitly apply the patch here, some patches might be ineffective in the
# pytest scenario.
from vllm_ascend.utils import adapt_patch  # noqa E402

adapt_patch(True)

from vllm.distributed.parallel_state import (  # noqa E402
    destroy_distributed_environment, destroy_model_parallel)

_M = TypeVar("_M")
_PromptMultiModalInput = Union[List[_M], List[List[_M]]]

PromptImageInput = _PromptMultiModalInput[Image.Image]
PromptAudioInput = _PromptMultiModalInput[Tuple[np.ndarray, int]]
PromptVideoInput = _PromptMultiModalInput[np.ndarray]


def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down distributed state and release NPU memory between tests."""
    destroy_model_parallel()
    destroy_distributed_environment()
    with contextlib.suppress(AssertionError):
        torch.distributed.destroy_process_group()
    if shutdown_ray:
        import ray  # Lazy import Ray
        ray.shutdown()
    gc.collect()
    torch.npu.empty_cache()
    torch.npu.reset_peak_memory_stats()


class VllmRunner:
    """Thin wrapper around ``vllm.LLM`` with helpers used by the test suite."""

    def __init__(
        self,
        model_name: str,
        task: TaskOption = "auto",
        tokenizer_name: Optional[str] = None,
        tokenizer_mode: str = "auto",
        # Use a smaller max model length; otherwise larger models cannot run
        # due to the kv cache size limit.
        max_model_len: int = 1024,
        dtype: str = "half",
        disable_log_stats: bool = True,
        tensor_parallel_size: int = 1,
        block_size: int = 16,
        enable_chunked_prefill: bool = False,
        swap_space: int = 4,
        enforce_eager: Optional[bool] = True,
        quantization: Optional[str] = None,
        **kwargs,
    ) -> None:
        self.model = LLM(
            model=model_name,
            task=task,
            tokenizer=tokenizer_name,
            tokenizer_mode=tokenizer_mode,
            trust_remote_code=True,
            dtype=dtype,
            swap_space=swap_space,
            enforce_eager=enforce_eager,
            disable_log_stats=disable_log_stats,
            tensor_parallel_size=tensor_parallel_size,
            max_model_len=max_model_len,
            block_size=block_size,
            enable_chunked_prefill=enable_chunked_prefill,
            quantization=quantization,
            **kwargs,
        )

    def get_inputs(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[TextPrompt]:
        if images is not None:
            assert len(prompts) == len(images)

        if videos is not None:
            assert len(prompts) == len(videos)

        if audios is not None:
            assert len(prompts) == len(audios)

        inputs = [TextPrompt(prompt=prompt) for prompt in prompts]
        if images is not None:
            for i, image in enumerate(images):
                if image is not None:
                    inputs[i]["multi_modal_data"] = {"image": image}

        if videos is not None:
            for i, video in enumerate(videos):
                if video is not None:
                    inputs[i]["multi_modal_data"] = {"video": video}

        if audios is not None:
            for i, audio in enumerate(audios):
                if audio is not None:
                    inputs[i]["multi_modal_data"] = {"audio": audio}

        return inputs

    def generate(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        outputs: List[Tuple[List[List[int]], List[str]]] = []
        for req_output in req_outputs:
            prompt_str = req_output.prompt
            prompt_ids = req_output.prompt_token_ids
            req_sample_output_ids: List[List[int]] = []
            req_sample_output_strs: List[str] = []
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                req_sample_output_ids.append(prompt_ids + output_ids)
                req_sample_output_strs.append(prompt_str + output_str)
            outputs.append((req_sample_output_ids, req_sample_output_strs))
        return outputs

    @staticmethod
    def _final_steps_generate_w_logprobs(
        req_outputs: List[RequestOutput],
    ) -> List[TokensTextLogprobsPromptLogprobs]:
        outputs: List[TokensTextLogprobsPromptLogprobs] = []
        for req_output in req_outputs:
            assert len(req_output.outputs) > 0
            for sample in req_output.outputs:
                output_str = sample.text
                output_ids = list(sample.token_ids)
                output_logprobs = sample.logprobs
            outputs.append((output_ids, output_str, output_logprobs,
                            req_output.prompt_logprobs))
        return outputs

    def generate_w_logprobs(
        self,
        prompts: List[str],
        sampling_params: SamplingParams,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)

        req_outputs = self.model.generate(inputs,
                                          sampling_params=sampling_params)

        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        # Omit prompt logprobs if not required by sampling params
        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
                if sampling_params.prompt_logprobs is None else
                toks_str_logsprobs_prompt_logprobs)

    def generate_encoder_decoder_w_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        sampling_params: SamplingParams,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Logprobs generation for vLLM encoder/decoder models
        '''
        assert sampling_params.logprobs is not None
        req_outputs = self.model.generate(encoder_decoder_prompts,
                                          sampling_params=sampling_params)
        toks_str_logsprobs_prompt_logprobs = (
            self._final_steps_generate_w_logprobs(req_outputs))
        # Omit prompt logprobs if not required by sampling params
        return ([x[0:-1] for x in toks_str_logsprobs_prompt_logprobs]
                if sampling_params.prompt_logprobs is None else
                toks_str_logsprobs_prompt_logprobs)

    def generate_greedy(
        self,
        prompts: List[str],
        max_tokens: int,
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[Tuple[List[int], str]]:
        greedy_params = SamplingParams(temperature=0.0, max_tokens=max_tokens)
        outputs = self.generate(prompts,
                                greedy_params,
                                images=images,
                                videos=videos,
                                audios=audios)
        return [(output_ids[0], output_str[0])
                for output_ids, output_str in outputs]

    def generate_greedy_logprobs(
        self,
        prompts: List[str],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
        images: Optional[PromptImageInput] = None,
        audios: Optional[PromptAudioInput] = None,
        videos: Optional[PromptVideoInput] = None,
        stop_token_ids: Optional[List[int]] = None,
        stop: Optional[List[str]] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
            stop_token_ids=stop_token_ids,
            stop=stop)
        return self.generate_w_logprobs(prompts,
                                        greedy_logprobs_params,
                                        images=images,
                                        audios=audios,
                                        videos=videos)

    def generate_encoder_decoder_greedy_logprobs(
        self,
        encoder_decoder_prompts: List[ExplicitEncoderDecoderPrompt[str, str]],
        max_tokens: int,
        num_logprobs: int,
        num_prompt_logprobs: Optional[int] = None,
    ) -> Union[List[TokensTextLogprobs],
               List[TokensTextLogprobsPromptLogprobs]]:
        '''
        Greedy logprobs generation for vLLM encoder/decoder models
        '''
        greedy_logprobs_params = SamplingParams(
            temperature=0.0,
            max_tokens=max_tokens,
            logprobs=num_logprobs,
            prompt_logprobs=num_prompt_logprobs,
        )
        return self.generate_encoder_decoder_w_logprobs(
            encoder_decoder_prompts, greedy_logprobs_params)

    def generate_beam_search(
        self,
        prompts: Union[List[str], List[List[int]]],
        beam_width: int,
        max_tokens: int,
    ) -> List[Tuple[List[List[int]], List[str]]]:
        if is_list_of(prompts, str, check="all"):
            prompts = [TextPrompt(prompt=prompt) for prompt in prompts]
        else:
            prompts = [
                TokensPrompt(prompt_token_ids=tokens) for tokens in prompts
            ]
        outputs = self.model.beam_search(
            prompts,
            BeamSearchParams(beam_width=beam_width, max_tokens=max_tokens))
        returned_outputs = []
        for output in outputs:
            token_ids = [x.tokens for x in output.sequences]
            texts = [x.text for x in output.sequences]
            returned_outputs.append((token_ids, texts))
        return returned_outputs

    def classify(self, prompts: List[str]) -> List[List[float]]:
        req_outputs = self.model.classify(prompts)
        return [req_output.outputs.probs for req_output in req_outputs]

    def encode(
        self,
        prompts: List[str],
        images: Optional[PromptImageInput] = None,
        videos: Optional[PromptVideoInput] = None,
        audios: Optional[PromptAudioInput] = None,
    ) -> List[List[float]]:
        inputs = self.get_inputs(prompts,
                                 images=images,
                                 videos=videos,
                                 audios=audios)
        req_outputs = self.model.embed(inputs)
        return [req_output.outputs.embedding for req_output in req_outputs]

    def score(
        self,
        text_1: Union[str, List[str]],
        text_2: Union[str, List[str]],
    ) -> List[float]:
        req_outputs = self.model.score(text_1, text_2)
        return [req_output.outputs.score for req_output in req_outputs]

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        del self.model
        cleanup_dist_env_and_memory()


@pytest.fixture(scope="session")
def vllm_runner():
    return VllmRunner


@pytest.fixture(params=list(PROMPT_TEMPLATES.keys()))
def prompt_template(request):
    return PROMPT_TEMPLATES[request.param]


@pytest.fixture(scope="session")
def ilama_lora_files():
    return snapshot_download(repo_id="jeeejeee/ilama-text2sql-spider")
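
# Example usage of the vllm_runner fixture (illustrative sketch only; the test
# name, model name and prompt below are placeholders chosen for this example
# and are not defined anywhere in this file):
#
#     def test_generate_greedy(vllm_runner):
#         with vllm_runner("Qwen/Qwen2.5-0.5B-Instruct",
#                          max_model_len=1024) as runner:
#             outputs = runner.generate_greedy(["Hello, my name is"],
#                                              max_tokens=8)
#         assert len(outputs) == 1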